/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package main
import (
"context"
"fmt"
"os"
_ "github.com/lib/pq"
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/application/handler"
"github.com/raoptimus/db-migrator.go/internal/domain/validator"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/adapter/urfavecli"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/log"
"github.com/urfave/cli/v3"
)
// maxConnAttempts is the largest value accepted for the --maxConnAttempts flag.
const maxConnAttempts = 100

// Version and GitCommit are rendered into the CLI version string.
// NOTE(review): presumably injected at build time via -ldflags "-X main.Version=…";
// both are empty strings otherwise.
var Version, GitCommit string
// main assembles the db-migrator CLI and runs it against os.Args. Any error
// returned by the selected subcommand is reported through logger.Fatal.
func main() {
	options := handler.Options{}
	logger := log.Std

	// handlers is assigned in the root Before hook; subcommand actions therefore
	// look it up only when they execute, never when the command tree is declared.
	var handlers *handler.Handlers

	// subcommand declares one CLI subcommand. Its Action resolves the handler via
	// pick at run time and adapts it to the urfave/cli action signature.
	subcommand := func(
		name, usage string,
		dsnIsRequired bool,
		pick func(*handler.Handlers) handler.Handler,
	) *cli.Command {
		return &cli.Command{
			Name:  name,
			Usage: usage,
			Action: func(ctx context.Context, c *cli.Command) error {
				return urfavecli.Adapt(pick(handlers))(ctx, c)
			},
			Flags: flags(&options, dsnIsRequired),
		}
	}

	root := &cli.Command{
		Name:      "Database Migration Tool",
		Usage:     "Database migration tool for different databases",
		UsageText: "db-migrator [command] [command options]\n\n Command-line options override environment variables.",
		Description: `This tool helps to manage database migrations.
It supports PostgreSQL and other databases that support SQL commands.
The tool can perform up, down, redo, create, history, new, and to operations on the database.
For more information, please refer to the documentation.
More details about the tool can be found at https://github.com/raoptimus/db-migrator.go`,
		Version: fmt.Sprintf("%s.rev[%s]", Version, GitCommit),
		Before: func(ctx context.Context, _ *cli.Command) (context.Context, error) {
			handlers = handler.NewHandlers(&options, logger)
			return ctx, nil
		},
		Commands: []*cli.Command{
			subcommand("up", "", true, func(h *handler.Handlers) handler.Handler { return h.Upgrade }),
			subcommand("down", "", true, func(h *handler.Handlers) handler.Handler { return h.Downgrade }),
			subcommand("redo", "", true, func(h *handler.Handlers) handler.Handler { return h.Redo }),
			subcommand("to", "", true, func(h *handler.Handlers) handler.Handler { return h.To }),
			// create only writes migration files, so it is the one command whose DSN is optional.
			subcommand("create", "", false, func(h *handler.Handlers) handler.Handler { return h.Create }),
			subcommand("history", "", true, func(h *handler.Handlers) handler.Handler { return h.History }),
			subcommand("new", "", true, func(h *handler.Handlers) handler.Handler { return h.HistoryNew }),
			subcommand(
				"release",
				"Apply all new migrations atomically in a single transaction",
				true,
				func(h *handler.Handlers) handler.Handler { return h.Release },
			),
			subcommand(
				"rollback",
				"Revert the latest release batch atomically in a single transaction",
				true,
				func(h *handler.Handlers) handler.Handler { return h.Rollback },
			),
		},
		DefaultCommand: "help",
	}

	if err := root.Run(context.Background(), os.Args); err != nil {
		logger.Fatal(err)
	}
}
// flags returns the option flags shared by every subcommand. Each flag writes
// its parsed value into the matching field of options and can also be sourced
// from the environment variable listed in Sources. dsnIsRequired controls
// whether --dsn is mandatory; main passes false only for the create command.
func flags(options *handler.Options, dsnIsRequired bool) []cli.Flag {
	return []cli.Flag{
		&cli.StringFlag{
			Name:        "placeholderCustom",
			Sources:     cli.EnvVars("PLACEHOLDER_CUSTOM"),
			Aliases:     []string{"phc"},
			Usage:       "PLACEHOLDER_CUSTOM",
			Destination: &options.PlaceholderCustom,
			Value:       "",
			Validator:   validator.ValidateIdentifier,
		},
		&cli.StringFlag{
			Name:        "dsn",
			Sources:     cli.EnvVars("DSN"),
			Aliases:     []string{"d"},
			Usage:       "DB connection string",
			Destination: &options.DSN,
			Required:    dsnIsRequired,
		},
		&cli.IntFlag{
			Name:        "maxConnAttempts",
			Sources:     cli.EnvVars("MAX_CONN_ATTEMPTS"),
			Aliases:     []string{"ma"},
			Usage:       "Maximum number of connection attempts",
			Destination: &options.MaxConnAttempts,
			Value:       1,
			Validator: func(i int) error {
				// Accepts values in the range [1, maxConnAttempts].
				return validator.ValidateStringLen(1, maxConnAttempts, i)
			},
		},
		&cli.StringFlag{
			Name:        "migrationPath",
			Sources:     cli.EnvVars("MIGRATION_PATH"),
			Aliases:     []string{"p"},
			Value:       "./migrations",
			Usage:       "Directory for migrated files",
			Destination: &options.Directory,
			Validator: func(s string) error {
				fi, err := os.Stat(s)
				if err != nil {
					return errors.WithMessage(err, "directory does not exist")
				}
				// os.Stat also succeeds for regular files. This case needs an
				// explicit error: errors.WithMessage(nil, …) returns nil, which
				// would silently accept a file path as the migrations directory.
				if !fi.IsDir() {
					return errors.Errorf("%q is not a directory", s)
				}
				return nil
			},
		},
		&cli.StringFlag{
			Name:        "migrationTable",
			Sources:     cli.EnvVars("MIGRATION_TABLE"),
			Aliases:     []string{"t"},
			Value:       "migration",
			Usage:       "Table name for history of migrates",
			Destination: &options.TableName,
			Action: func(ctx context.Context, command *cli.Command, s string) error {
				return validator.ValidateIdentifier(s)
			},
		},
		&cli.StringFlag{
			Name:        "migrationClusterName",
			Sources:     cli.EnvVars("MIGRATION_CLUSTER_NAME"),
			Aliases:     []string{"cn"},
			Value:       "",
			Usage:       "Cluster name for history of migrates",
			Destination: &options.ClusterName,
			Action: func(ctx context.Context, command *cli.Command, s string) error {
				return validator.ValidateIdentifier(s)
			},
		},
		&cli.BoolFlag{
			Name:        "migrationReplicated",
			Sources:     cli.EnvVars("MIGRATION_REPLICATED"),
			Aliases:     []string{"cr"},
			Value:       false,
			Usage:       "Using replicated experimental function to clickhouse for history table of migrates",
			Destination: &options.Replicated,
		},
		&cli.BoolFlag{
			Name:        "compact",
			Sources:     cli.EnvVars("COMPACT"),
			Aliases:     []string{"c"},
			Usage:       "Indicates whether the console output should be compacted.",
			Value:       false,
			Destination: &options.Compact,
		},
		&cli.BoolFlag{
			Name:        "interactive",
			Sources:     cli.EnvVars("INTERACTIVE"),
			Aliases:     []string{"i"},
			Usage:       "Whether to run the command interactively",
			Value:       true,
			Destination: &options.Interactive,
		},
		&cli.BoolFlag{
			Name:        "dryRun",
			Sources:     cli.EnvVars("DRY_RUN"),
			Aliases:     []string{"dry"},
			Usage:       "Show SQL that would be executed without actually running it",
			Value:       false,
			Destination: &options.DryRun,
		},
	}
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package dbmigrator
import (
"github.com/raoptimus/db-migrator.go/internal/infrastructure/dal/connection"
)
// NewConnection opens a database connection described by dsn without pinging
// or retrying; see TryConnection for the retrying variant.
// Supported drivers: clickhouse, postgres, mysql, tarantool.
//
// The DSN has the form driver://username:password@host:port/dbname?options,
// for example:
//   - PostgreSQL: postgres://user:pass@localhost:5432/mydb?sslmode=disable
//   - MySQL: mysql://user:pass@localhost:3306/mydb
//   - ClickHouse: clickhouse://user:pass@localhost:9000/mydb?compress=true
//   - Tarantool: tarantool://user:pass@localhost:3301/mydb
//
//nolint:ireturn // intentionally returns interface to hide internal implementation
func NewConnection(dsn string) (Connection, error) {
	conn, err := connection.New(dsn)
	return conn, err
}
// TryConnection creates a database connection for dsn and pings it, retrying
// up to maxAttempts times with a 1-second pause between attempts. A
// maxAttempts below 1 is treated as 1.
//
//nolint:ireturn // intentionally returns interface to hide internal implementation
func TryConnection(dsn string, maxAttempts int) (Connection, error) {
	conn, err := connection.Try(dsn, maxAttempts)
	return conn, err
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package handler
import (
"fmt"
"os"
"regexp"
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/helper/console"
"github.com/raoptimus/db-migrator.go/internal/helper/timex"
)
// fileModeExecutable is the permission mode used when the migrations
// directory has to be created.
const fileModeExecutable = 0o755

var (
	// ErrInvalidFileName is returned when a migration name contains invalid characters.
	ErrInvalidFileName = errors.New("the migration name should contain letters, digits, underscore and/or backslash characters only")

	// regexpFileName accepts non-empty names built only from word characters
	// ([0-9A-Za-z_]) and backslashes.
	regexpFileName = regexp.MustCompile(`^[\w\\]+$`)
)
// Create implements the create subcommand, which writes a new pair of up/down
// migration files.
type Create struct {
	options         *Options
	file            File
	logger          Logger
	fileNameBuilder FileNameBuilder
}

// NewCreate returns a Create handler that reads its settings from options,
// writes files through file, reports progress via logger and derives file
// names with fileNameBuilder.
func NewCreate(
	options *Options,
	logger Logger,
	file File,
	fileNameBuilder FileNameBuilder,
) *Create {
	c := new(Create)
	c.options = options
	c.file = file
	c.logger = logger
	c.fileNameBuilder = fileNameBuilder
	return c
}
// Handle creates the up and down files for the migration named by the first
// command argument. The file version is the current time (YYMMDD_HHMMSS)
// joined to that name with an underscore. In interactive mode the user must
// confirm first; declining returns nil without creating anything.
func (c *Create) Handle(cmd *Command) error {
	if !cmd.Args.Present() {
		return errors.WithStack(ErrInvalidFileName)
	}

	name := cmd.Args.First()
	if !regexpFileName.MatchString(name) {
		return ErrInvalidFileName
	}

	stamp := timex.StdTime.Now().Format("060102_150405")
	version := stamp + "_" + name

	upFile, _ := c.fileNameBuilder.Up(version, true)
	downFile, _ := c.fileNameBuilder.Down(version, true)

	prompt := fmt.Sprintf(
		"Create new migration files: \n'%s' and \n'%s'?\n",
		upFile,
		downFile,
	)
	if confirmed := !c.options.Interactive || console.Confirm(prompt); !confirmed {
		return nil
	}

	if err := c.createDirectory(c.options.Directory); err != nil {
		return err
	}
	if err := c.file.Create(upFile); err != nil {
		return err
	}
	if err := c.file.Create(downFile); err != nil {
		return err
	}

	c.logger.Success("New migration created successfully.")
	return nil
}
// createDirectory makes sure path exists, creating it with fileModeExecutable
// permissions when c.file reports it absent. Errors from the existence check
// are returned unchanged; creation failures are wrapped with the path.
func (c *Create) createDirectory(path string) error {
	exists, err := c.file.Exists(path)
	if err != nil {
		return err
	}
	if exists {
		return nil
	}
	// MkdirAll also creates missing parents (e.g. "db/migrations"), where
	// os.Mkdir would fail. It returns nil if the directory appears concurrently.
	if err := os.MkdirAll(path, fileModeExecutable); err != nil {
		return errors.Wrapf(err, "creating directory %s", path)
	}
	return nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package handler
import (
"github.com/raoptimus/db-migrator.go/internal/helper/console"
)
// Downgrade implements the down subcommand, which reverts previously applied
// migrations.
type Downgrade struct {
	options         *Options
	presenter       Presenter
	fileNameBuilder FileNameBuilder
}

// NewDowngrade returns a Downgrade handler wired to the given options,
// presenter and file-name builder.
func NewDowngrade(options *Options, presenter Presenter, fileNameBuilder FileNameBuilder) *Downgrade {
	d := new(Downgrade)
	d.options = options
	d.presenter = presenter
	d.fileNameBuilder = fileNameBuilder
	return d
}
// Handle reverts applied migrations. The optional first argument is the step
// count: it defaults to minLimit, and "all" is passed to svc.Migrations as
// limit 0. The plan is shown first. In interactive mode the user must confirm;
// declining returns nil. On the first failed revert the error is returned and
// the presenter is told how many migrations were reverted before it.
func (d *Downgrade) Handle(cmd *Command, svc MigrationService) error {
	limit, err := stepOrDefault(cmd, minLimit)
	if err != nil {
		return err
	}

	migrations, err := svc.Migrations(cmd.Context(), limit)
	if err != nil {
		return err
	}

	migrationsCount := migrations.Len()
	if migrationsCount == 0 {
		d.presenter.ShowNoMigrationsToRevert()
		return nil
	}

	d.presenter.ShowDowngradePlan(migrations)
	question := d.presenter.AskDowngradeConfirmation(migrationsCount)
	if d.options.Interactive && !console.Confirm(question) {
		return nil
	}

	// Reuse the shared revert loop instead of duplicating it here. It reports
	// a failure through ShowDowngradeError itself before returning the error.
	if _, err := revertMigrations(cmd, svc, d.presenter, d.fileNameBuilder, migrations); err != nil {
		return err
	}

	d.presenter.ShowDowngradeSuccess(migrationsCount)
	return nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package handler
import "context"
// Command is a single CLI invocation handed to a Handler: its positional
// arguments plus an optional execution context.
type Command struct {
	Args Args
	ctx  context.Context
}

// Context returns the context attached via WithContext, or
// context.Background() when none has been set.
func (c *Command) Context() context.Context {
	if c.ctx == nil {
		return context.Background()
	}
	return c.ctx
}

// WithContext returns a shallow copy of c that carries ctx; c itself is left
// untouched. Passing a nil ctx is a programming error and panics.
func (c *Command) WithContext(ctx context.Context) *Command {
	if ctx == nil {
		panic("nil context")
	}
	clone := *c
	clone.ctx = ctx
	return &clone
}

// Args gives read access to a command's positional arguments.
//
//go:generate mockery
type Args interface {
	// First returns the first argument, or an empty string when there is none.
	First() string
	// Present reports whether any arguments were supplied.
	Present() bool
	// Slice returns a copy of the underlying argument slice.
	Slice() []string
}

// Handler executes a Command.
//
//go:generate mockery
type Handler interface {
	Handle(cmd *Command) error
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package handler
import (
"github.com/raoptimus/db-migrator.go/internal/application/presenter"
"github.com/raoptimus/db-migrator.go/internal/domain/builder"
iohelp "github.com/raoptimus/db-migrator.go/internal/helper/io"
)
// Handlers bundles one Handler per CLI subcommand.
type Handlers struct {
	Create     Handler
	Upgrade    Handler
	Downgrade  Handler
	Redo       Handler
	To         Handler
	History    Handler
	HistoryNew Handler
	Release    Handler
	Rollback   Handler
}

// NewHandlers wires every subcommand handler against the shared options and
// logger. Every handler except Create is wrapped in a ServiceWrapHandler,
// which opens a database connection and MigrationService for each run.
//
// NOTE(review): options.Directory is read once here and passed by value to
// builder.NewFileName. If options is filled in by flag parsing after this
// call, the file-name builder will not see the final directory; confirm the
// call order.
func NewHandlers(options *Options, logger Logger) *Handlers {
	names := builder.NewFileName(iohelp.StdFile, options.Directory)
	view := presenter.NewMigrationPresenter(logger)

	// withService adapts a service-aware handler to the plain Handler interface.
	withService := func(h ServiceHandler) Handler {
		return NewServiceWrapHandler(options, logger, h)
	}

	return &Handlers{
		Create:     NewCreate(options, logger, iohelp.StdFile, names),
		Upgrade:    withService(NewUpgrade(options, view, names)),
		Downgrade:  withService(NewDowngrade(options, view, names)),
		Redo:       withService(NewRedo(options, view, names)),
		To:         withService(NewTo(options, view, names)),
		History:    withService(NewHistory(options, view)),
		HistoryNew: withService(NewHistoryNew(options, view)),
		Release:    withService(NewRelease(options, view, names)),
		Rollback:   withService(NewRollback(options, view, names)),
	}
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package handler
import (
"fmt"
"regexp"
"strconv"
"time"
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/domain/model"
)
// all is the step argument that selects every migration; stepOrDefault maps it
// to a limit of 0.
const all = "all"

// minLimit is the default step count for commands that act on a single
// migration by default.
const minLimit = 1

// timestampLength is the length of the YYMMDD_HHMMSS version prefix.
const timestampLength = 13
var (
	// ErrArgumentMustBeGreaterThanZero is returned when a numeric step argument is below 1.
	ErrArgumentMustBeGreaterThanZero = errors.New("the step argument must be greater than 0")

	// ErrMissingDownFiles is returned when a rollback finds down migration files missing.
	ErrMissingDownFiles = errors.New("missing down migration files")
)
// stepOrDefault interprets the first command argument as a step count.
// It returns:
//   - defaults when no argument (or an empty one) is given;
//   - 0 for the literal "all", which callers pass on as an unlimited limit;
//   - the parsed integer when it is >= 1.
//
// Non-numeric input and values below 1 produce an error with a -1 step.
func stepOrDefault(cmd *Command, defaults int) (int, error) {
	var arg string
	if cmd.Args.Present() {
		arg = cmd.Args.First()
	}

	if arg == "" {
		return defaults, nil
	}
	if arg == all {
		return 0, nil
	}

	step, err := strconv.Atoi(arg)
	if err != nil {
		return -1, fmt.Errorf("the step argument %s is not valid", arg)
	}
	if step < 1 {
		return -1, ErrArgumentMustBeGreaterThanZero
	}
	return step, nil
}
var (
	// ErrTargetVersionRequired is returned when no target version argument is supplied.
	ErrTargetVersionRequired = errors.New("target version is required")
)
// extractTimestamp returns the leading YYMMDD_HHMMSS part of a migration
// version, i.e. its first timestampLength bytes. Shorter inputs are returned
// unchanged.
//
// Examples:
//   - "251002_184510" -> "251002_184510"
//   - "251002_184510_change_scheme" -> "251002_184510"
func extractTimestamp(version string) string {
	if len(version) >= timestampLength {
		return version[:timestampLength]
	}
	return version
}
// Patterns used by parseTargetVersion. They are compiled once at package
// initialisation rather than on every call.
var (
	// targetTimestampPattern matches a bare YYMMDD_HHMMSS timestamp.
	targetTimestampPattern = regexp.MustCompile(`^\d{6}_\d{6}$`)
	// targetFullNamePattern matches YYMMDD_HHMMSS_<name> and captures the timestamp.
	targetFullNamePattern = regexp.MustCompile(`^(\d{6}_\d{6})_.+$`)
	// targetUnixPattern matches an unsigned 9- or 10-digit UNIX timestamp.
	targetUnixPattern = regexp.MustCompile(`^\d{9,10}$`)
)

// parseTargetVersion normalizes a user-supplied migration target to the
// YYMMDD_HHMMSS format. Supported inputs:
//   - Timestamp: "150101_185401"
//   - Full name: "150101_185401_create_news_table"
//   - DateTime (interpreted as UTC): "2015-01-01 18:54:01"
//   - UNIX timestamp, 9 or 10 digits only (formatted in UTC): "1392853618"
//
// It returns ErrTargetVersionRequired for empty input and an
// "invalid version format" error for anything else.
func parseTargetVersion(input string) (string, error) {
	if input == "" {
		return "", ErrTargetVersionRequired
	}

	// Case 1: already in timestamp format (YYMMDD_HHMMSS).
	if targetTimestampPattern.MatchString(input) {
		return input, nil
	}

	// Case 2: full migration name (YYMMDD_HHMMSS_name); keep the timestamp part.
	if matches := targetFullNamePattern.FindStringSubmatch(input); matches != nil {
		return matches[1], nil
	}

	// Case 3: DateTime string "2015-01-01 18:54:01".
	if dt, err := time.Parse("2006-01-02 15:04:05", input); err == nil {
		return dt.Format("060102_150405"), nil
	}

	// Case 4: UNIX timestamp. Require plain digits: strconv.ParseInt alone
	// also accepts a leading "+" or "-", which would turn inputs such as
	// "-12345678" into bogus pre-1970 versions. The lower bound of 9 digits
	// still admits timestamps from before 2001.
	if targetUnixPattern.MatchString(input) {
		if timestamp, err := strconv.ParseInt(input, 10, 64); err == nil {
			return time.Unix(timestamp, 0).UTC().Format("060102_150405"), nil
		}
	}

	return "", fmt.Errorf("invalid version format: %s", input)
}
// applyMigrations applies migrations in slice order via svc.ApplyFile. It
// returns how many were applied. On the first failure it tells presenter how
// far it got and returns that count together with the error.
func applyMigrations(
	cmd *Command,
	svc MigrationService,
	presenter Presenter,
	fileNameBuilder FileNameBuilder,
	migrations model.Migrations,
) (int, error) {
	total := len(migrations)
	for idx := range migrations {
		m := &migrations[idx]
		upFile, safely := fileNameBuilder.Up(m.Version, false)
		if err := svc.ApplyFile(cmd.Context(), m, upFile, safely); err != nil {
			// Exactly idx migrations succeeded before this one failed.
			presenter.ShowUpgradeError(idx, total)
			return idx, err
		}
	}
	return total, nil
}
// revertMigrations reverts migrations in slice order via svc.RevertFile. It
// returns how many were reverted. On the first failure it tells presenter how
// far it got and returns that count together with the error.
func revertMigrations(
	cmd *Command,
	svc MigrationService,
	presenter Presenter,
	fileNameBuilder FileNameBuilder,
	migrations model.Migrations,
) (int, error) {
	total := len(migrations)
	for idx := range migrations {
		m := &migrations[idx]
		downFile, safely := fileNameBuilder.Down(m.Version, false)
		if err := svc.RevertFile(cmd.Context(), m, downFile, safely); err != nil {
			// Exactly idx migrations were reverted before this one failed.
			presenter.ShowDowngradeError(idx, total)
			return idx, err
		}
	}
	return total, nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package handler
// defaultGetHistoryLimit is the number of entries shown by the history and
// new commands when no step argument is given.
const defaultGetHistoryLimit = 10

// History implements the history subcommand, which lists applied migrations.
type History struct {
	options   *Options
	presenter Presenter
}

// NewHistory returns a History handler using the given options and presenter.
func NewHistory(
	options *Options,
	presenter Presenter,
) *History {
	h := new(History)
	h.options = options
	h.presenter = presenter
	return h
}
// Handle prints applied migrations. The optional first argument limits the
// number of entries: it defaults to defaultGetHistoryLimit, and "all" (limit 0)
// selects the all-history header instead of the limited one.
func (h *History) Handle(cmd *Command, svc MigrationService) error {
	limit, err := stepOrDefault(cmd, defaultGetHistoryLimit)
	if err != nil {
		return err
	}

	applied, err := svc.Migrations(cmd.Context(), limit)
	if err != nil {
		return err
	}

	switch n := applied.Len(); {
	case n == 0:
		h.presenter.ShowNoMigrationsToRevert()
		return nil
	case limit > 0:
		h.presenter.ShowHistoryHeader(n)
	default:
		h.presenter.ShowAllHistoryHeader(n)
	}

	h.presenter.PrintMigrations(applied, true)
	return nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package handler
// HistoryNew implements the new subcommand, which lists migrations that have
// not been applied yet.
type HistoryNew struct {
	options   *Options
	presenter Presenter
}

// NewHistoryNew returns a HistoryNew handler using the given options and presenter.
func NewHistoryNew(
	options *Options,
	presenter Presenter,
) *HistoryNew {
	h := new(HistoryNew)
	h.options = options
	h.presenter = presenter
	return h
}
// Handle prints pending (not yet applied) migrations. The optional first
// argument caps how many are printed (default defaultGetHistoryLimit; "all"
// means no cap). When the cap truncates the list, the limited header reports
// both the cap and the total.
func (h *HistoryNew) Handle(cmd *Command, svc MigrationService) error {
	limit, err := stepOrDefault(cmd, defaultGetHistoryLimit)
	if err != nil {
		return err
	}

	pending, err := svc.NewMigrations(cmd.Context())
	if err != nil {
		return err
	}

	switch total := pending.Len(); {
	case total == 0:
		h.presenter.ShowNoNewMigrations()
		return nil
	case limit > 0 && total > limit:
		h.presenter.ShowNewMigrationsLimitedHeader(limit, total)
		pending = pending[:limit]
	default:
		h.presenter.ShowNewMigrationsHeader(total)
	}

	h.presenter.PrintMigrations(pending, false)
	return nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package handler
import (
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/domain/validator"
)
// maxConnAttempts is the largest MaxConnAttempts value accepted by Validate.
const maxConnAttempts = 100

// Options holds the settings shared by all migration commands.
type Options struct {
	PlaceholderCustom string
	DSN               string
	// MaxConnAttempts is the number of connection attempts before giving up.
	MaxConnAttempts int
	// Directory is where migration files are read from and written to.
	Directory string
	// TableName is the migration history table.
	TableName   string
	ClusterName string

	Replicated, Compact, Interactive bool

	MaxSQLOutputLength int
	// DryRun disables interactive prompts and wraps the repository in a dry-run implementation.
	DryRun bool
}
// Validate checks PlaceholderCustom, MaxConnAttempts (must be within
// [1, maxConnAttempts]), TableName and ClusterName, in that order. It returns
// the first failure, prefixed with the offending field's name.
func (o *Options) Validate() error {
	checks := []struct {
		field string
		check func() error
	}{
		{"placeholderCustom", func() error { return validator.ValidateIdentifier(o.PlaceholderCustom) }},
		{"maxConnAttempts", func() error { return validator.ValidateStringLen(1, maxConnAttempts, o.MaxConnAttempts) }},
		{"tableName", func() error { return validator.ValidateIdentifier(o.TableName) }},
		{"clusterName", func() error { return validator.ValidateIdentifier(o.ClusterName) }},
	}

	for _, c := range checks {
		if err := c.check(); err != nil {
			return errors.WithMessage(err, c.field)
		}
	}
	return nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package handler
import (
"github.com/raoptimus/db-migrator.go/internal/domain/model"
"github.com/raoptimus/db-migrator.go/internal/helper/console"
)
// Redo implements the redo subcommand, which reverts applied migrations and
// then applies them again.
type Redo struct {
	options         *Options
	presenter       Presenter
	fileNameBuilder FileNameBuilder
}

// NewRedo returns a Redo handler wired to the given options, presenter and
// file-name builder.
func NewRedo(
	options *Options,
	presenter Presenter,
	fileNameBuilder FileNameBuilder,
) *Redo {
	r := new(Redo)
	r.options = options
	r.presenter = presenter
	r.fileNameBuilder = fileNameBuilder
	return r
}
// Handle reverts up to the requested number of applied migrations (default
// minLimit, "all" = no limit) in the order svc.Migrations returns them. It
// then re-applies them in the opposite order. Any failure in either pass
// calls ShowRedoError and returns the error.
func (r *Redo) Handle(cmd *Command, svc MigrationService) error {
	limit, err := stepOrDefault(cmd, minLimit)
	if err != nil {
		return err
	}

	ctx := cmd.Context()
	migrations, err := svc.Migrations(ctx, limit)
	if err != nil {
		return err
	}

	count := migrations.Len()
	if count == 0 {
		r.presenter.ShowNoMigrationsToRevert()
		return nil
	}

	r.presenter.ShowRedoPlan(migrations)
	if prompt := r.presenter.AskRedoConfirmation(count); r.options.Interactive && !console.Confirm(prompt) {
		return nil
	}

	// Revert pass: keep a copy of each migration once it has been reverted,
	// so the re-apply pass works from the post-revert values.
	reverted := make(model.Migrations, 0, count)
	for i := range migrations {
		m := &migrations[i]
		downFile, safely := r.fileNameBuilder.Down(m.Version, false)
		if err := svc.RevertFile(ctx, m, downFile, safely); err != nil {
			r.presenter.ShowRedoError()
			return err
		}
		reverted = append(reverted, *m)
	}

	// Re-apply pass, walking the reverted list backwards.
	for i := count - 1; i >= 0; i-- {
		m := &reverted[i]
		upFile, safely := r.fileNameBuilder.Up(m.Version, false)
		if err := svc.ApplyFile(ctx, m, upFile, safely); err != nil {
			r.presenter.ShowRedoError()
			return err
		}
	}

	r.presenter.ShowRedoSuccess(count)
	return nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package handler
import (
"context"
"time"
"github.com/raoptimus/db-migrator.go/internal/helper/console"
)
// Release implements the release subcommand: every pending migration is
// applied inside a single transaction. All migrations in one release share
// the same apply_time, which is what allows the batch to be rolled back as a
// unit.
type Release struct {
	options         *Options
	presenter       Presenter
	fileNameBuilder FileNameBuilder
}

// NewRelease returns a Release handler wired to the given options, presenter
// and file-name builder.
func NewRelease(
	options *Options,
	presenter Presenter,
	fileNameBuilder FileNameBuilder,
) *Release {
	r := new(Release)
	r.options = options
	r.presenter = presenter
	r.fileNameBuilder = fileNameBuilder
	return r
}
// Handle applies every pending migration inside svc.ExecInTransaction. Each
// one is recorded with the same Unix-seconds apply time. On failure
// ShowReleaseError is shown and the transaction's error is returned.
func (r *Release) Handle(cmd *Command, svc MigrationService) error {
	pending, err := svc.NewMigrations(cmd.Context())
	if err != nil {
		return err
	}

	total := pending.Len()
	if total == 0 {
		r.presenter.ShowNoNewMigrations()
		return nil
	}

	r.presenter.ShowUpgradePlan(pending, total)
	prompt := r.presenter.AskUpgradeConfirmation(total)
	if r.options.Interactive && !console.Confirm(prompt) {
		return nil
	}

	// One timestamp for the whole batch.
	batchTime := time.Now().Unix()

	applyAll := func(ctx context.Context) error {
		for i := range pending {
			m := &pending[i]
			upFile, _ := r.fileNameBuilder.Up(m.Version, false)
			if err := svc.ApplyFileWithApplyTime(ctx, m, upFile, batchTime); err != nil {
				return err
			}
		}
		return nil
	}

	if err := svc.ExecInTransaction(cmd.Context(), applyAll); err != nil {
		r.presenter.ShowReleaseError()
		return err
	}

	r.presenter.ShowUpgradeSuccess(total)
	return nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package handler
import (
"context"
"github.com/raoptimus/db-migrator.go/internal/helper/console"
)
// Rollback implements the rollback subcommand: the latest release batch is
// reverted inside a single transaction. The batch is identified by
// MAX(apply_time), and every migration in it is reverted.
type Rollback struct {
	options         *Options
	presenter       Presenter
	fileNameBuilder FileNameBuilder
}

// NewRollback returns a Rollback handler wired to the given options, presenter
// and file-name builder.
func NewRollback(
	options *Options,
	presenter Presenter,
	fileNameBuilder FileNameBuilder,
) *Rollback {
	r := new(Rollback)
	r.options = options
	r.presenter = presenter
	r.fileNameBuilder = fileNameBuilder
	return r
}
// Handle reverts the migrations returned by svc.LatestReleaseMigrations inside
// svc.ExecInTransaction. Before showing the plan, it checks that every down
// file exists. If any are missing it reports them and returns
// ErrMissingDownFiles without touching the database.
func (r *Rollback) Handle(cmd *Command, svc MigrationService) error {
	batch, err := svc.LatestReleaseMigrations(cmd.Context())
	if err != nil {
		return err
	}
	if batch.Len() == 0 {
		r.presenter.ShowNoMigrationsToRevert()
		return nil
	}

	// Check all down files up front, before the plan is shown or any
	// transaction is opened.
	var missing []string
	for i := range batch {
		downFile, _ := r.fileNameBuilder.Down(batch[i].Version, false)
		found, checkErr := svc.FileExists(downFile)
		if checkErr != nil {
			return checkErr
		}
		if !found {
			missing = append(missing, batch[i].Version)
		}
	}
	if len(missing) > 0 {
		r.presenter.ShowMissingDownFiles(missing)
		return ErrMissingDownFiles
	}

	r.presenter.ShowDowngradePlan(batch)
	if prompt := r.presenter.AskDowngradeConfirmation(batch.Len()); r.options.Interactive && !console.Confirm(prompt) {
		return nil
	}

	revertAll := func(ctx context.Context) error {
		for i := range batch {
			m := &batch[i]
			downFile, _ := r.fileNameBuilder.Down(m.Version, false)
			if err := svc.RevertFile(ctx, m, downFile, false); err != nil {
				return err
			}
		}
		return nil
	}

	if err := svc.ExecInTransaction(cmd.Context(), revertAll); err != nil {
		r.presenter.ShowRollbackError()
		return err
	}

	r.presenter.ShowDowngradeSuccess(batch.Len())
	return nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package handler
import (
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/domain/service"
"github.com/raoptimus/db-migrator.go/internal/helper/dsn"
iohelp "github.com/raoptimus/db-migrator.go/internal/helper/io"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/dal/repository"
)
// NewMigrationService builds a MigrationService on top of conn. The history
// repository is configured from options.TableName, ClusterName and
// Replicated. With options.DryRun set, that repository is wrapped by
// service.NewDryRunRepository. The DSN username and password are passed to
// the service. The public API (service.go) also uses this function to create
// service instances.
//
//nolint:ireturn // Returns interface by design for dependency inversion
func NewMigrationService(
	options *Options,
	logger Logger,
	conn Connection,
) (MigrationService, error) {
	creds, err := dsn.Parse(options.DSN)
	if err != nil {
		return nil, errors.WithMessage(err, "parsing DSN")
	}

	repo, err := repository.New(conn, &repository.Options{
		TableName:   options.TableName,
		ClusterName: options.ClusterName,
		Replicated:  options.Replicated,
	})
	if err != nil {
		return nil, err
	}

	var repoForService service.Repository = repo
	if options.DryRun {
		repoForService = service.NewDryRunRepository(repo)
	}

	svcOptions := &service.Options{
		MaxSQLOutputLength: options.MaxSQLOutputLength,
		Directory:          options.Directory,
		Compact:            options.Compact,
		PlaceholderCustom:  options.PlaceholderCustom,
		ClusterName:        options.ClusterName,
		Username:           creds.Username,
		Password:           creds.Password,
	}
	return service.NewMigration(svcOptions, logger, iohelp.StdFile, repoForService), nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package handler
import (
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/dal/connection"
)
// ServiceHandler is implemented by handlers that need a live MigrationService.
type ServiceHandler interface {
	Handle(cmd *Command, svc MigrationService) error
}

// ServiceWrapHandler adapts a ServiceHandler to the plain Handler interface.
// Each call opens a database connection and builds a MigrationService before
// delegating.
type ServiceWrapHandler struct {
	options *Options
	logger  Logger
	handler ServiceHandler
}

// NewServiceWrapHandler returns a ServiceWrapHandler that delegates to handler
// using the given options and logger.
func NewServiceWrapHandler(
	options *Options,
	logger Logger,
	handler ServiceHandler,
) *ServiceWrapHandler {
	w := new(ServiceWrapHandler)
	w.options, w.logger, w.handler = options, logger, handler
	return w
}
// Handle connects to the database (retrying up to options.MaxConnAttempts
// times), builds a MigrationService and runs the wrapped handler. The
// connection is closed when Handle returns.
//
// In dry-run mode it switches the shared options to non-interactive before
// doing anything else. This mutates *w.options, which other handlers share.
func (w *ServiceWrapHandler) Handle(cmd *Command) error {
	if w.options.DryRun {
		w.options.Interactive = false
		w.logger.Warn("[DRY RUN] No changes will be applied to the database.")
	}

	conn, connErr := connection.Try(w.options.DSN, w.options.MaxConnAttempts)
	if connErr != nil {
		return errors.WithMessage(connErr, "failed to connect to database")
	}
	defer conn.Close()

	svc, svcErr := NewMigrationService(w.options, w.logger, conn)
	if svcErr != nil {
		return errors.WithMessage(svcErr, "failed to create migration service")
	}

	return w.handler.Handle(cmd, svc)
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package handler
import (
"github.com/raoptimus/db-migrator.go/internal/domain/model"
"github.com/raoptimus/db-migrator.go/internal/helper/console"
)
// To migrates the database to a specific version: it applies pending
// migrations up to the target, or reverts applied migrations newer than it.
type To struct {
	options         *Options
	presenter       Presenter
	fileNameBuilder FileNameBuilder
}

// NewTo creates a new To handler instance.
func NewTo(
	options *Options,
	presenter Presenter,
	fileNameBuilder FileNameBuilder,
) *To {
	h := new(To)
	h.options = options
	h.presenter = presenter
	h.fileNameBuilder = fileNameBuilder
	return h
}

// Handle reads the target version from the first command argument and migrates
// towards it.
//
// If a migration whose timestamp equals the target is already applied, newer
// migrations are reverted; otherwise pending migrations up to and including the
// target are applied. Returns ErrTargetVersionRequired when no argument is given.
func (t *To) Handle(cmd *Command, svc MigrationService) error {
	if !cmd.Args.Present() {
		return ErrTargetVersionRequired
	}
	targetVersion, err := parseTargetVersion(cmd.Args.First())
	if err != nil {
		return err
	}

	// A limit of 0 lets the service choose its default history size.
	applied, err := svc.Migrations(cmd.Context(), 0)
	if err != nil {
		return err
	}

	for _, m := range applied {
		if extractTimestamp(m.Version) == targetVersion {
			// Target already applied: roll back everything newer than it.
			return t.handleDowngrade(cmd, svc, targetVersion)
		}
	}

	// Target not applied yet: apply everything up to and including it.
	return t.handleUpgrade(cmd, svc, targetVersion)
}

// handleUpgrade applies every pending migration whose timestamp is not newer
// than targetVersion, after showing the plan and (in interactive mode) asking
// for confirmation.
func (t *To) handleUpgrade(cmd *Command, svc MigrationService, targetVersion string) error {
	pending, err := svc.NewMigrations(cmd.Context())
	if err != nil {
		return err
	}

	selected := make(model.Migrations, 0)
	for _, m := range pending {
		// Compare by timestamp prefix only (e.g. "251002_184510_change_scheme" -> "251002_184510").
		if extractTimestamp(m.Version) > targetVersion {
			continue
		}
		selected = append(selected, m)
	}

	count := len(selected)
	if count == 0 {
		t.presenter.ShowNoNewMigrations()
		return nil
	}

	// Apply in ascending version order.
	selected.SortByVersion()
	t.presenter.ShowUpgradePlan(selected, count)

	question := t.presenter.AskUpgradeConfirmation(count)
	if t.options.Interactive && !console.Confirm(question) {
		return nil
	}

	applied, err := applyMigrations(cmd, svc, t.presenter, t.fileNameBuilder, selected)
	if err != nil {
		return err
	}
	t.presenter.ShowUpgradeSuccess(applied)
	return nil
}

// handleDowngrade reverts every applied migration whose timestamp is newer than
// targetVersion, after showing the plan and (in interactive mode) asking for
// confirmation.
func (t *To) handleDowngrade(cmd *Command, svc MigrationService, targetVersion string) error {
	// A limit of 0 lets the service choose its default history size.
	applied, err := svc.Migrations(cmd.Context(), 0)
	if err != nil {
		return err
	}

	toRevert := make(model.Migrations, 0)
	for _, m := range applied {
		// Compare by timestamp prefix only (e.g. "251002_184510_change_scheme" -> "251002_184510").
		if extractTimestamp(m.Version) <= targetVersion {
			continue
		}
		toRevert = append(toRevert, m)
	}

	if len(toRevert) == 0 {
		// Nothing newer than the target is applied.
		t.presenter.ShowNoMigrationsToRevert()
		return nil
	}

	// NOTE(review): relies on svc.Migrations returning newest first, which is the
	// order a rollback needs; the list is intentionally not re-sorted here.
	t.presenter.ShowDowngradePlan(toRevert)

	question := t.presenter.AskDowngradeConfirmation(len(toRevert))
	if t.options.Interactive && !console.Confirm(question) {
		return nil
	}

	reverted, err := revertMigrations(cmd, svc, t.presenter, t.fileNameBuilder, toRevert)
	if err != nil {
		return err
	}
	t.presenter.ShowDowngradeSuccess(reverted)
	return nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package handler
import (
"github.com/raoptimus/db-migrator.go/internal/helper/console"
)
const (
	defaultUpgradeLimit    = 0
	migratedUpSuccessfully = "Migrated up successfully"
)

// Upgrade applies pending migrations, optionally limited to the first N.
type Upgrade struct {
	options         *Options
	presenter       Presenter
	fileNameBuilder FileNameBuilder
}

// NewUpgrade creates a new Upgrade handler instance.
func NewUpgrade(
	options *Options,
	presenter Presenter,
	fileNameBuilder FileNameBuilder,
) *Upgrade {
	h := new(Upgrade)
	h.options = options
	h.presenter = presenter
	h.fileNameBuilder = fileNameBuilder
	return h
}

// Handle applies pending migrations in the order returned by the service.
//
// The optional step argument caps how many migrations are applied
// (defaultUpgradeLimit, i.e. 0, means all). The plan is shown first, and in
// interactive mode the user must confirm. Migrations are applied one by one;
// on the first failure the presenter reports how many succeeded and the error
// is returned.
func (u *Upgrade) Handle(cmd *Command, svc MigrationService) error {
	limit, err := stepOrDefault(cmd, defaultUpgradeLimit)
	if err != nil {
		return err
	}

	pending, err := svc.NewMigrations(cmd.Context())
	if err != nil {
		return err
	}

	total := pending.Len()
	if total == 0 {
		u.presenter.ShowNoNewMigrations()
		return nil
	}

	// A positive limit smaller than the pending count trims the batch.
	if limit > 0 && limit < total {
		pending = pending[:limit]
	}
	count := pending.Len()

	u.presenter.ShowUpgradePlan(pending, total)
	question := u.presenter.AskUpgradeConfirmation(count)
	if u.options.Interactive && !console.Confirm(question) {
		return nil
	}

	for i := range pending {
		fileName, safely := u.fileNameBuilder.Up(pending[i].Version, false)
		if err := svc.ApplyFile(cmd.Context(), &pending[i], fileName, safely); err != nil {
			// Exactly i migrations were applied before this failure.
			u.presenter.ShowUpgradeError(i, count)
			return err
		}
	}

	u.presenter.ShowUpgradeSuccess(count)
	return nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package presenter
import (
"fmt"
"github.com/raoptimus/db-migrator.go/internal/domain/log"
"github.com/raoptimus/db-migrator.go/internal/domain/model"
"github.com/raoptimus/db-migrator.go/internal/helper/plural"
)
// Logger is the logging abstraction MigrationPresenter writes through.
// It is a type alias of the domain log.Logger, so any value satisfying that
// type can be passed in.
//
//go:generate mockery
type Logger = log.Logger
// MigrationPresenter turns migration events into user-facing messages.
// Every message is written through the injected Logger.
type MigrationPresenter struct {
	logger Logger
}

// NewMigrationPresenter returns a MigrationPresenter that writes to logger.
func NewMigrationPresenter(logger Logger) *MigrationPresenter {
	presenter := new(MigrationPresenter)
	presenter.logger = logger
	return presenter
}

// ShowUpgradePlan announces how many migrations are about to be applied and
// lists them. When migrations is only part of the pending set (of size total),
// both numbers are reported.
func (p *MigrationPresenter) ShowUpgradePlan(migrations model.Migrations, total int) {
	selected := migrations.Len()
	if selected != total {
		p.logger.Warnf(
			"Total %d out of %d new %s to be applied: \n",
			selected,
			total,
			plural.Migration(total),
		)
	} else {
		p.logger.Warnf("Total %d new %s to be applied: \n", selected, plural.Migration(selected))
	}
	p.PrintMigrations(migrations, false)
}

// ShowDowngradePlan announces how many migrations are about to be reverted and
// lists them.
func (p *MigrationPresenter) ShowDowngradePlan(migrations model.Migrations) {
	count := migrations.Len()
	p.logger.Warnf("Total %d %s to be reverted: \n", count, plural.Migration(count))
	p.PrintMigrations(migrations, false)
}

// PrintMigrations logs one tab-indented line per migration version, prefixed
// with the formatted apply time when withTime is true.
func (p *MigrationPresenter) PrintMigrations(migrations model.Migrations, withTime bool) {
	for i := range migrations {
		m := migrations[i]
		if !withTime {
			p.logger.Infof("\t%s\n", m.Version)
			continue
		}
		p.logger.Infof("\t(%s) %s\n", m.ApplyTimeFormat(), m.Version)
	}
}

// AskUpgradeConfirmation builds the yes/no question asked before applying.
func (p *MigrationPresenter) AskUpgradeConfirmation(count int) string {
	noun := plural.Migration(count)
	return fmt.Sprintf("Apply the above %s?", noun)
}

// AskDowngradeConfirmation builds the yes/no question asked before reverting.
func (p *MigrationPresenter) AskDowngradeConfirmation(count int) string {
	noun := plural.Migration(count)
	return fmt.Sprintf("Revert the above %d %s?", count, noun)
}

// ShowUpgradeError reports how many of total migrations were applied before a
// failure stopped the batch.
func (p *MigrationPresenter) ShowUpgradeError(applied, total int) {
	p.logger.Errorf("%d from %d %s applied.\n", applied, total, plural.MigrationWas(applied))
	p.logger.Error("The rest of the migrations are canceled.")
}

// ShowDowngradeError reports how many of total migrations were reverted before
// a failure stopped the batch.
func (p *MigrationPresenter) ShowDowngradeError(reverted, total int) {
	const format = "%d from %d %s reverted.\nMigration failed. The rest of the migrations are canceled.\n"
	p.logger.Errorf(format, reverted, total, plural.MigrationWas(reverted))
}

// ShowUpgradeSuccess reports a fully applied batch of count migrations.
func (p *MigrationPresenter) ShowUpgradeSuccess(count int) {
	verb := plural.MigrationWas(count)
	p.logger.Successf("%d %s applied\n", count, verb)
	p.logger.Success("Migrated up successfully")
}

// ShowDowngradeSuccess reports a fully reverted batch of count migrations.
func (p *MigrationPresenter) ShowDowngradeSuccess(count int) {
	verb := plural.MigrationWas(count)
	p.logger.Successf("%d %s reverted\n", count, verb)
	p.logger.Success("Migrated down successfully")
}

// ShowNoNewMigrations reports that there is nothing to apply.
func (p *MigrationPresenter) ShowNoNewMigrations() {
	const msg = "No new migrations found. Your system is up-to-date."
	p.logger.Success(msg)
}

// ShowNoMigrationsToRevert reports that there is nothing to revert.
func (p *MigrationPresenter) ShowNoMigrationsToRevert() {
	const msg = "No migration has been done before."
	p.logger.Success(msg)
}

// ShowRedoPlan announces how many migrations are about to be redone and lists them.
func (p *MigrationPresenter) ShowRedoPlan(migrations model.Migrations) {
	count := migrations.Len()
	p.logger.Warnf("Total %d %s to be redone: \n", count, plural.Migration(count))
	p.PrintMigrations(migrations, false)
}

// AskRedoConfirmation builds the yes/no question asked before redoing.
func (p *MigrationPresenter) AskRedoConfirmation(count int) string {
	noun := plural.Migration(count)
	return fmt.Sprintf("Redo the above %d %s?\n", count, noun)
}

// ShowRedoError reports that a redo batch stopped on a failure.
func (p *MigrationPresenter) ShowRedoError() {
	const msg = "Migration failed. The rest of the migrations are canceled.\n"
	p.logger.Error(msg)
}

// ShowRedoSuccess reports a fully redone batch of count migrations.
func (p *MigrationPresenter) ShowRedoSuccess(count int) {
	verb := plural.MigrationWas(count)
	p.logger.Warnf("%d %s redone.\n", count, verb)
	p.logger.Success("Migration redone successfully.\n")
}

// ShowHistoryHeader introduces a history listing limited to count entries.
func (p *MigrationPresenter) ShowHistoryHeader(count int) {
	noun := plural.Migration(count)
	p.logger.Warnf("Showing the last %d %s: \n", count, noun)
}

// ShowAllHistoryHeader introduces the complete list of applied migrations.
func (p *MigrationPresenter) ShowAllHistoryHeader(count int) {
	phrase := plural.MigrationHas(count)
	p.logger.Warnf("Total %d %s been applied before: \n", count, phrase)
}

// ShowNewMigrationsHeader introduces the complete list of pending migrations.
func (p *MigrationPresenter) ShowNewMigrationsHeader(count int) {
	noun := plural.Migration(count)
	p.logger.Warnf("Found %d new %s \n", count, noun)
}

// ShowNewMigrationsLimitedHeader introduces a pending-migrations listing that
// shows only `shown` of `total` entries.
func (p *MigrationPresenter) ShowNewMigrationsLimitedHeader(shown, total int) {
	noun := plural.Migration(total)
	p.logger.Warnf("Showing %d out of %d new %s \n", shown, total, noun)
}

// ShowReleaseError reports that a release batch failed.
func (p *MigrationPresenter) ShowReleaseError() {
	const msg = "Release failed. All changes have been rolled back.\n"
	p.logger.Error(msg)
}

// ShowRollbackError reports that a rollback batch failed.
func (p *MigrationPresenter) ShowRollbackError() {
	const msg = "Rollback failed. Some changes may have been partially reverted.\n"
	p.logger.Error(msg)
}

// ShowMissingDownFiles lists the versions whose down migration file is missing.
func (p *MigrationPresenter) ShowMissingDownFiles(versions []string) {
	p.logger.Error("Cannot rollback: missing down migration files for versions:\n")
	for i := range versions {
		p.logger.Errorf("\t%s\n", versions[i])
	}
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package builder
import (
"path/filepath"
)
const (
	safelyUpSuffix     = ".safe.up.sql"
	safelyDownSuffix   = ".safe.down.sql"
	unsafelyUpSuffix   = ".up.sql"
	unsafelyDownSuffix = ".down.sql"
)

// FileName resolves the path of a migration's up/down SQL file inside the
// migrations directory, choosing between the plain and the ".safe" variant.
type FileName struct {
	file                File
	migrationsDirectory string
}

// NewFileName returns a FileName that checks file existence through file and
// builds paths under migrationsDirectory.
func NewFileName(file File, migrationsDirectory string) *FileName {
	b := new(FileName)
	b.file = file
	b.migrationsDirectory = migrationsDirectory
	return b
}

// Up returns the path of the "up" file for version and whether it is the
// ".safe" variant.
func (s *FileName) Up(version string, forceSafely bool) (fname string, safely bool) {
	fname, safely = s.build(version, safelyUpSuffix, unsafelyUpSuffix, forceSafely)
	return fname, safely
}

// Down returns the path of the "down" file for version and whether it is the
// ".safe" variant.
func (s *FileName) Down(version string, forceSafely bool) (fname string, safely bool) {
	fname, safely = s.build(version, safelyDownSuffix, unsafelyDownSuffix, forceSafely)
	return fname, safely
}

// build returns the plain file when it exists and forceSafely is false;
// otherwise it returns the ".safe" file with safely=true. The existence check
// error is ignored on purpose: any failure falls back to the safe variant.
func (s *FileName) build(
	version,
	safelySuffix,
	unsafelySuffix string,
	forceSafely bool,
) (fname string, safely bool) {
	safelyFile := filepath.Join(s.migrationsDirectory, version+safelySuffix)
	unsafelyFile := filepath.Join(s.migrationsDirectory, version+unsafelySuffix)

	exists, _ := s.file.Exists(unsafelyFile)
	if forceSafely || !exists {
		return safelyFile, true
	}
	return unsafelyFile, false
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package model
import (
"sort"
"time"
)
// Migration is a single migration record in the domain layer.
type Migration struct {
	Version     string
	ApplyTime   int64
	BodySQL     string
	ExecutedSQL string
	Release     string
}

// ApplyTimeFormat renders ApplyTime, interpreted as seconds since the Unix
// epoch, as "YYYY-MM-DD HH:MM:SS" in the process's local time zone.
func (m Migration) ApplyTimeFormat() string {
	const layout = "2006-01-02 15:04:05"
	appliedAt := time.Unix(m.ApplyTime, 0)
	return appliedAt.Format(layout)
}

// Migrations is a list of Migration records; it implements sort.Interface
// ordered by Version.
type Migrations []Migration

// Len returns the number of migrations.
func (s Migrations) Len() int { return len(s) }

// Less orders migrations by plain string comparison of their Version.
func (s Migrations) Less(i, j int) bool {
	left, right := s[i].Version, s[j].Version
	return left < right
}

// Swap exchanges the migrations at positions i and j.
func (s Migrations) Swap(i, j int) {
	tmp := s[i]
	s[i] = s[j]
	s[j] = tmp
}

// SortByVersion sorts the list in place by ascending Version.
func (s Migrations) SortByVersion() {
	sort.Sort(s)
}
package service
import (
"context"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/dal/entity"
)
// DryRunRepository wraps a Repository for dry-run mode: reads are forwarded to
// the wrapped repository, while every write operation is a no-op.
//
// Once CreateMigrationHistoryTable has been called, the history table exists
// only virtually (it was never created in the database). Reads that would
// query it then return empty results instead of reaching the wrapped repository.
type DryRunRepository struct {
	repo                Repository
	virtualTableCreated bool
}

// Compile-time check that DryRunRepository can stand in for a Repository.
var _ Repository = (*DryRunRepository)(nil)

// NewDryRunRepository wraps repo for dry-run mode.
func NewDryRunRepository(repo Repository) *DryRunRepository {
	return &DryRunRepository{
		repo: repo,
	}
}

// ExistsMigration returns true if version of migration exists.
// It always returns false once the history table is only virtual.
func (d *DryRunRepository) ExistsMigration(ctx context.Context, version string) (bool, error) {
	if d.virtualTableCreated {
		return false, nil
	}
	return d.repo.ExistsMigration(ctx, version)
}

// Migrations returns applied migrations history.
// It returns an empty result once the history table is only virtual.
func (d *DryRunRepository) Migrations(ctx context.Context, limit int) (entity.Migrations, error) {
	if d.virtualTableCreated {
		return nil, nil
	}
	return d.repo.Migrations(ctx, limit)
}

// HasMigrationHistoryTable returns true if migration history table exists,
// counting a virtually created table as existing.
func (d *DryRunRepository) HasMigrationHistoryTable(ctx context.Context) (exists bool, err error) {
	if d.virtualTableCreated {
		return true, nil
	}
	return d.repo.HasMigrationHistoryTable(ctx)
}

// InsertMigration is a no-op in dry-run mode.
func (d *DryRunRepository) InsertMigration(ctx context.Context, version string) error {
	return nil
}

// RemoveMigration is a no-op in dry-run mode.
func (d *DryRunRepository) RemoveMigration(ctx context.Context, version string) error {
	return nil
}

// ExecQuery is a no-op in dry-run mode; the query is not sent to the database.
func (d *DryRunRepository) ExecQuery(ctx context.Context, query string, args ...any) error {
	return nil
}

// ExecQueryTransaction runs fnTx without opening a real transaction.
//
// The callback is still executed so that work it performs through this
// repository goes through the no-op write methods above. This lets a dry run
// walk through (and log) the statements of safe migrations and batched
// operations instead of silently skipping them. A nil callback is a no-op.
func (d *DryRunRepository) ExecQueryTransaction(ctx context.Context, fnTx func(ctx context.Context) error) error {
	if fnTx == nil {
		return nil
	}
	return fnTx(ctx)
}

// DropMigrationHistoryTable is a no-op in dry-run mode.
func (d *DryRunRepository) DropMigrationHistoryTable(ctx context.Context) error {
	return nil
}

// CreateMigrationHistoryTable marks the history table as virtually created
// without touching the database.
func (d *DryRunRepository) CreateMigrationHistoryTable(ctx context.Context) error {
	d.virtualTableCreated = true
	return nil
}

// MigrationsCount returns the total number of applied migrations.
// It returns 0 once the history table is only virtual.
func (d *DryRunRepository) MigrationsCount(ctx context.Context) (int, error) {
	if d.virtualTableCreated {
		return 0, nil
	}
	return d.repo.MigrationsCount(ctx)
}

// InsertMigrationWithApplyTime is a no-op in dry-run mode.
func (d *DryRunRepository) InsertMigrationWithApplyTime(ctx context.Context, version string, applyTime int64) error {
	return nil
}

// MigrationsByMaxApplyTime returns migrations that share the maximum apply_time
// value. It returns an empty result once the history table is only virtual.
func (d *DryRunRepository) MigrationsByMaxApplyTime(ctx context.Context) (entity.Migrations, error) {
	if d.virtualTableCreated {
		return nil, nil
	}
	return d.repo.MigrationsByMaxApplyTime(ctx)
}

// TableNameWithSchema returns the full table name including schema if applicable.
//
// NOTE(review): the name does not depend on the table existing, yet an empty
// string is returned after the virtual creation; confirm that this is intended.
func (d *DryRunRepository) TableNameWithSchema() string {
	if d.virtualTableCreated {
		return ""
	}
	return d.repo.TableNameWithSchema()
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package mapper
import (
"github.com/raoptimus/db-migrator.go/internal/domain/model"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/dal/entity"
)
// EntityToDomain maps a DAL entity.Migration onto a domain model.Migration.
// Only Version and ApplyTime are carried over; other domain fields stay zero.
func EntityToDomain(e entity.Migration) model.Migration {
	var m model.Migration
	m.Version = e.Version
	m.ApplyTime = e.ApplyTime
	return m
}
// DomainToEntity maps a domain model.Migration onto a DAL entity.Migration.
// Only Version and ApplyTime are carried over.
func DomainToEntity(m model.Migration) entity.Migration {
	var e entity.Migration
	e.Version = m.Version
	e.ApplyTime = m.ApplyTime
	return e
}
// EntitiesToDomain maps every entity with EntityToDomain, preserving order.
// The result is never nil, even for empty input.
func EntitiesToDomain(entities entity.Migrations) model.Migrations {
	out := make(model.Migrations, 0, len(entities))
	for _, e := range entities {
		out = append(out, EntityToDomain(e))
	}
	return out
}
// DomainsToEntities maps every domain migration with DomainToEntity, preserving
// order. The result is never nil, even for empty input.
func DomainsToEntities(migrations model.Migrations) entity.Migrations {
	out := make(entity.Migrations, 0, len(migrations))
	for _, m := range migrations {
		out = append(out, DomainToEntity(m))
	}
	return out
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package service
import (
	"context"
	"fmt"
	"path/filepath"
	"regexp"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/pkg/errors"
	"github.com/raoptimus/db-migrator.go/internal/domain/model"
	"github.com/raoptimus/db-migrator.go/internal/domain/service/mapper"
	"github.com/raoptimus/db-migrator.go/internal/domain/validator"
	"github.com/raoptimus/db-migrator.go/internal/helper/sqlio"
)
const (
	// baseMigration is the reserved version inserted right after the history
	// table is created. It cannot be applied or reverted
	// (ErrMigrationVersionReserved) and is filtered out of listings.
	baseMigration = "000000_000000_base"
	// defaultLimit is the history size Migrations uses when called with limit < 1.
	defaultLimit = 10000
	// maxLimit is how many history rows NewMigrations reads to decide which
	// migration files are still pending.
	maxLimit = 100000
	// regexpFileNameGroupCount is the length of a successful
	// regexpFileName.FindStringSubmatch result: the full match plus 4 groups.
	regexpFileNameGroupCount = 5
	credentialMask           = "****" // Mask for hiding credentials in output
)

// ErrMigrationVersionReserved occurs when attempting to apply or revert the reserved base migration version.
var ErrMigrationVersionReserved = errors.New("migration version reserved")

var (
	// regexpFileName matches "<version>[.safe].(up|down).sql". Group 1 is the
	// version, group 3 is "safe" when present and group 4 is the direction.
	regexpFileName = regexp.MustCompile(`^(\d{6}_?\d{6}[A-Za-z0-9_]+)\.((safe)\.)?(up|down)\.sql$`)
)
// Migration is the service handling migration operations.
// It orchestrates migration file processing, SQL execution, and history tracking.
type Migration struct {
	options *Options   // directory, placeholder values and SQL output settings
	logger  Logger     // progress and per-statement output
	file    File       // file-system access (used by FileExists)
	repo    Repository // history table access and SQL execution
}
// NewMigration assembles a Migration service from its configuration, logger,
// file-system accessor and repository.
func NewMigration(
	options *Options,
	logger Logger,
	file File,
	repo Repository,
) *Migration {
	svc := new(Migration)
	svc.options = options
	svc.logger = logger
	svc.file = file
	svc.repo = repo
	return svc
}
// InitializeTableHistory makes sure the migration history table exists.
//
// If the table is missing it is created and seeded with the base migration.
// When seeding fails, the new table is dropped again. If that drop also fails,
// the insert error is returned wrapped with the drop error's message.
func (m *Migration) InitializeTableHistory(ctx context.Context) error {
	exists, err := m.repo.HasMigrationHistoryTable(ctx)
	switch {
	case err != nil:
		return err
	case exists:
		return nil
	}

	m.logger.Warnf("Creating migration history table %s...\n", m.repo.TableNameWithSchema())
	if err = m.repo.CreateMigrationHistoryTable(ctx); err != nil {
		return err
	}

	insertErr := m.repo.InsertMigration(ctx, baseMigration)
	if insertErr == nil {
		m.logger.Success("Done")
		return nil
	}

	// Undo the half-initialized table so the next run starts from scratch.
	if dropErr := m.repo.DropMigrationHistoryTable(ctx); dropErr != nil {
		return errors.Wrap(insertErr, dropErr.Error())
	}
	return insertErr
}
// Migrations returns up to limit applied migrations from the history table
// (defaultLimit when limit < 1), initializing the table first if needed.
// The first base migration record found is dropped from the result.
func (m *Migration) Migrations(ctx context.Context, limit int) (model.Migrations, error) {
	if err := m.InitializeTableHistory(ctx); err != nil {
		return nil, err
	}

	effectiveLimit := limit
	if effectiveLimit < 1 {
		effectiveLimit = defaultLimit
	}

	entities, err := m.repo.Migrations(ctx, effectiveLimit)
	if err != nil {
		return nil, err
	}

	migrations := mapper.EntitiesToDomain(entities)
	for i, migration := range migrations {
		if migration.Version != baseMigration {
			continue
		}
		return append(migrations[:i], migrations[i+1:]...), nil
	}
	return migrations, nil
}
// NewMigrations returns, sorted by ascending version, the migrations that exist
// as "*.up.sql" files in the configured directory but are not recorded in the
// history table.
//
// Each version is reported once, even when both "<version>.up.sql" and
// "<version>.safe.up.sql" exist (both match the glob). Applied versions are
// looked up in a set instead of re-scanning the history for every file.
//
// An error is returned if the history table cannot be initialized or read, the
// directory cannot be globbed, or a file name fails validation.
func (m *Migration) NewMigrations(ctx context.Context) (model.Migrations, error) {
	if err := m.InitializeTableHistory(ctx); err != nil {
		return nil, err
	}
	entities, err := m.repo.Migrations(ctx, maxLimit)
	if err != nil {
		return nil, err
	}
	files, err := filepath.Glob(filepath.Join(m.options.Directory, "*.up.sql"))
	if err != nil {
		return nil, err
	}

	// seen holds every version that must not be reported: applied versions up
	// front, then each new version as soon as it has been reported.
	applied := mapper.EntitiesToDomain(entities)
	seen := make(map[string]struct{}, len(applied)+len(files))
	for _, migration := range applied {
		if migration.Version == baseMigration {
			// The base record is not a real migration and is not treated as applied.
			continue
		}
		seen[migration.Version] = struct{}{}
	}

	newMigrations := make(model.Migrations, 0)
	for _, file := range files {
		baseFilename := filepath.Base(file)
		if err := validator.ValidateFileName(baseFilename); err != nil {
			return nil, errors.Wrap(err, baseFilename)
		}

		groups := regexpFileName.FindStringSubmatch(baseFilename)
		if len(groups) != regexpFileNameGroupCount {
			return nil, fmt.Errorf("file name %s is invalid", baseFilename)
		}

		version := groups[1]
		if _, ok := seen[version]; ok {
			continue
		}
		seen[version] = struct{}{}
		newMigrations = append(newMigrations, model.Migration{Version: version})
	}

	newMigrations.SortByVersion()
	return newMigrations, nil
}
// ApplySQL executes upSQL as the migration identified by version and records
// it in the history table.
//
// When safely is true the statements run inside a transaction. The reserved
// base version is rejected with ErrMigrationVersionReserved. Elapsed time is
// logged for both the success and the failure case.
func (m *Migration) ApplySQL(
	ctx context.Context,
	safely bool,
	version,
	upSQL string,
) error {
	if version == baseMigration {
		return ErrMigrationVersionReserved
	}

	m.logger.Warnf("*** applying %s\n", version)
	sqlScanner := sqlio.NewScanner(strings.NewReader(upSQL))

	startedAt := time.Now()
	applyErr := m.apply(ctx, sqlScanner, safely)
	seconds := time.Since(startedAt).Seconds()

	if applyErr != nil {
		m.logger.Errorf("*** failed to apply %s (time: %.3fs)\n", version, seconds)
		return applyErr
	}

	if err := m.repo.InsertMigration(ctx, version); err != nil {
		return err
	}
	// todo: save downSQL
	m.logger.Successf("*** applied %s (time: %.3fs)\n", version, seconds)
	return nil
}
// RevertSQL executes downSQL to undo the migration identified by version and
// removes it from the history table.
//
// When safely is true the statements run inside a transaction. The reserved
// base version is rejected with ErrMigrationVersionReserved. Elapsed time is
// logged for both the success and the failure case.
func (m *Migration) RevertSQL(
	ctx context.Context,
	safely bool,
	version,
	downSQL string,
) error {
	if version == baseMigration {
		return ErrMigrationVersionReserved
	}

	m.logger.Warnf("*** reverting %s\n", version)
	sqlScanner := sqlio.NewScanner(strings.NewReader(downSQL))

	startedAt := time.Now()
	revertErr := m.apply(ctx, sqlScanner, safely)
	seconds := time.Since(startedAt).Seconds()

	if revertErr != nil {
		m.logger.Errorf("*** failed to revert %s (time: %.3fs)\n", version, seconds)
		return revertErr
	}

	if err := m.repo.RemoveMigration(ctx, version); err != nil {
		return err
	}
	m.logger.Warnf("*** reverted %s (time: %.3fs)\n", version, seconds)
	return nil
}
// ApplyFile executes the SQL in fileName as migration and records it in the
// history table with the repository's default apply time.
// When safely is true the statements run inside a transaction.
func (m *Migration) ApplyFile(ctx context.Context, migration *model.Migration, fileName string, safely bool) error {
	record := func(ctx context.Context, version string) error {
		return m.repo.InsertMigration(ctx, version)
	}
	return m.applyFileCore(ctx, migration, fileName, safely, record)
}
// ApplyFileWithApplyTime executes the SQL in fileName as migration and records
// it with the explicit applyTime. The release command uses this so that every
// migration in one batch shares an apply_time, which later identifies the batch
// for rollback. Statements never run in their own transaction (safely is false),
// because release already wraps the batch in an outer one.
func (m *Migration) ApplyFileWithApplyTime(
	ctx context.Context,
	migration *model.Migration,
	fileName string,
	applyTime int64,
) error {
	record := func(ctx context.Context, version string) error {
		return m.repo.InsertMigrationWithApplyTime(ctx, version, applyTime)
	}
	return m.applyFileCore(ctx, migration, fileName, false, record)
}
// applyFileCore is the shared implementation behind ApplyFile and
// ApplyFileWithApplyTime.
//
// It rejects the reserved base version, opens a scanner over fileName (closed
// on return; a close failure is only logged) and executes its statements. On
// success it calls insertFn to record the migration. Elapsed time is logged in
// both outcomes.
func (m *Migration) applyFileCore(
	ctx context.Context,
	migration *model.Migration,
	fileName string,
	safely bool,
	insertFn func(ctx context.Context, version string) error,
) error {
	version := migration.Version
	if version == baseMigration {
		return ErrMigrationVersionReserved
	}
	m.logger.Warnf("*** applying %s\n", version)

	sqlScanner, err := m.scannerByFile(fileName)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := sqlScanner.Close(); cerr != nil {
			m.logger.Warnf("failed to close SQL scanner: %v", cerr)
		}
	}()

	startedAt := time.Now()
	err = m.apply(ctx, sqlScanner, safely)
	seconds := time.Since(startedAt).Seconds()
	if err != nil {
		m.logger.Errorf("*** failed to apply %s (time: %.3fs)\n", version, seconds)
		return err
	}

	if err = insertFn(ctx, version); err != nil {
		return err
	}
	m.logger.Successf("*** applied %s (time: %.3fs)\n", version, seconds)
	return nil
}
// RevertFile executes the SQL in fileName to undo migration and removes it from
// the history table.
//
// When safely is true the statements run inside a transaction. The reserved
// base version is rejected with ErrMigrationVersionReserved. The scanner is
// closed on return (a close failure is only logged). Elapsed time is logged in
// both outcomes.
func (m *Migration) RevertFile(ctx context.Context, migration *model.Migration, fileName string, safely bool) error {
	version := migration.Version
	if version == baseMigration {
		return ErrMigrationVersionReserved
	}
	m.logger.Warnf("*** reverting %s\n", version)

	sqlScanner, err := m.scannerByFile(fileName)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := sqlScanner.Close(); cerr != nil {
			m.logger.Warnf("failed to close SQL scanner: %v", cerr)
		}
	}()

	startedAt := time.Now()
	err = m.apply(ctx, sqlScanner, safely)
	seconds := time.Since(startedAt).Seconds()
	if err != nil {
		m.logger.Errorf("*** failed to revert %s (time: %.3fs)\n", version, seconds)
		return err
	}

	if err = m.repo.RemoveMigration(ctx, version); err != nil {
		return err
	}
	m.logger.Warnf("*** reverted %s (time: %.3fs)\n", version, seconds)
	return nil
}
// BeginCommand prepares the loggable form of sqlQuery (credentials masked,
// length-limited) and, unless compact output is enabled, logs it. It returns
// the start time to pass to EndCommand.
func (m *Migration) BeginCommand(sqlQuery string) time.Time {
	printable := m.SQLQueryOutput(sqlQuery)
	if m.options.Compact {
		return time.Now()
	}
	m.logger.Infof(" > execute SQL: %s ...\n", printable)
	return time.Now()
}
// ExecQuery runs sqlQuery through the repository, bracketed by BeginCommand and
// EndCommand. EndCommand is skipped when the query fails.
func (m *Migration) ExecQuery(ctx context.Context, sqlQuery string) error {
	startedAt := m.BeginCommand(sqlQuery)
	err := m.repo.ExecQuery(ctx, sqlQuery)
	if err == nil {
		m.EndCommand(startedAt)
	}
	return err
}
// SQLQueryOutput prepares sqlQuery for logging: credentials are masked and,
// when MaxSQLOutputLength is positive and exceeded, the text is truncated and
// suffixed with "...".
//
// MaxSQLOutputLength is a byte budget. Truncation backs off to the start of a
// rune that straddles the limit, so multi-byte UTF-8 characters are never
// split into invalid byte sequences.
func (m *Migration) SQLQueryOutput(sqlQuery string) string {
	// Mask before truncating: cutting first could split a credential so that
	// the masking no longer recognizes it.
	output := m.sanitizeCredentials(sqlQuery)

	limit := m.options.MaxSQLOutputLength
	if limit <= 0 || limit >= len(output) {
		return output
	}

	// A UTF-8 rune is at most utf8.UTFMax bytes, so at most UTFMax-1 steps back
	// reach its first byte. For invalid input, fall back to the raw byte limit.
	cut := limit
	for back := 0; back < utf8.UTFMax-1 && cut > 0 && !utf8.RuneStart(output[cut]); back++ {
		cut--
	}
	if !utf8.RuneStart(output[cut]) {
		cut = limit
	}
	return output[:cut] + "..."
}
// EndCommand logs the time elapsed since start for a SQL command opened with
// BeginCommand.
//
// Like BeginCommand, it is silent in compact mode. As a result, the per-statement
// " > execute SQL" line and its " done" line are printed or suppressed together
// (previously the condition was inverted, and the format had a stray quote).
func (m *Migration) EndCommand(start time.Time) {
	if m.options.Compact {
		return
	}
	m.logger.Infof(" done (time: %.3fs)\n", time.Since(start).Seconds())
}
// Exists reports whether version is recorded in the migration history table.
// It does not initialize the table first.
func (m *Migration) Exists(ctx context.Context, version string) (bool, error) {
	found, err := m.repo.ExistsMigration(ctx, version)
	return found, err
}
// LatestReleaseMigrations returns the migrations the repository reports for the
// maximum apply_time value (the most recent release batch), excluding the base
// migration. InitializeTableHistory is called first, and its error is returned
// as-is.
func (m *Migration) LatestReleaseMigrations(ctx context.Context) (model.Migrations, error) {
	if err := m.InitializeTableHistory(ctx); err != nil {
		return nil, err
	}

	entities, err := m.repo.MigrationsByMaxApplyTime(ctx)
	if err != nil {
		return nil, err
	}

	all := mapper.EntitiesToDomain(entities)
	latest := make(model.Migrations, 0, len(all))
	for i := range all {
		if all[i].Version == baseMigration {
			// The base migration is bookkeeping, not part of a release.
			continue
		}
		latest = append(latest, all[i])
	}
	return latest, nil
}
// ExecInTransaction runs fn through the repository's ExecQueryTransaction, i.e.
// inside a database transaction, and returns whatever that call returns.
func (m *Migration) ExecInTransaction(ctx context.Context, fn func(ctx context.Context) error) error {
	err := m.repo.ExecQueryTransaction(ctx, fn)
	return err
}
// FileExists reports whether fileName exists, as determined by the configured
// file helper's Exists method.
func (m *Migration) FileExists(fileName string) (bool, error) {
	exists, err := m.file.Exists(fileName)
	return exists, err
}
// apply executes every statement produced by scanner.
//
// Before each statement is run, the {cluster}, {placeholder_custom}, {username}
// and {password} placeholders are substituted, in that order. Empty statements
// are skipped. Execution stops at the first failure or when ctx is cancelled.
// When safely is true, the whole run happens inside one repository transaction.
func (m *Migration) apply(ctx context.Context, scanner *sqlio.Scanner, safely bool) error {
	run := func(ctx context.Context) error {
		for scanner.Scan() {
			if err := ctx.Err(); err != nil {
				return err
			}

			stmt := scanner.SQL()
			if stmt == "" {
				continue
			}

			// The replacements are applied one after another, in exactly this order.
			substitutions := [...][2]string{
				{"{cluster}", m.options.ClusterName},
				{"{placeholder_custom}", m.options.PlaceholderCustom},
				{"{username}", m.options.Username},
				{"{password}", m.options.Password},
			}
			for _, sub := range substitutions {
				stmt = strings.ReplaceAll(stmt, sub[0], sub[1])
			}

			if err := m.ExecQuery(ctx, stmt); err != nil {
				return err
			}
		}
		return scanner.Err()
	}

	if safely {
		return m.repo.ExecQueryTransaction(ctx, run)
	}
	return run(ctx)
}
// scannerByFile opens fileName and returns an SQL statement scanner over it.
//
// The errors distinguish three cases:
//   - the existence check itself failed (for example, permission denied);
//   - the file is missing;
//   - the file could not be opened.
//
// Closing the returned scanner closes the opened file, provided the file
// implements io.Closer.
func (m *Migration) scannerByFile(fileName string) (*sqlio.Scanner, error) {
	exists, err := m.file.Exists(fileName)
	if err != nil {
		// The check failed; that does not mean the file is absent.
		return nil, errors.Wrapf(err, "checking migration file %s", fileName)
	}
	if !exists {
		return nil, fmt.Errorf("migration file %s does not exist", fileName)
	}

	f, err := m.file.Open(fileName)
	if err != nil {
		return nil, errors.Wrapf(err, "opening migration file %s", fileName)
	}
	return sqlio.NewScanner(f), nil
}
// sanitizeCredentials replaces every occurrence of the configured password and
// username in sql with credentialMask, so they do not appear in logs or console
// output. Empty credentials are ignored.
//
// The longer value is masked first. If one credential contains the other (for
// example, username "app_admin" and password "admin"), masking the shorter one
// first would break up the longer one and leave part of it visible.
func (m *Migration) sanitizeCredentials(sql string) string {
	secrets := [2]string{m.options.Password, m.options.Username}
	if len(secrets[1]) > len(secrets[0]) {
		secrets[0], secrets[1] = secrets[1], secrets[0]
	}
	for _, secret := range secrets {
		// An empty old string would make ReplaceAll insert the mask between every rune.
		if secret != "" {
			sql = strings.ReplaceAll(sql, secret, credentialMask)
		}
	}
	return sql
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package validator
import (
"regexp"
"time"
"github.com/pkg/errors"
)
// patternFileName describes a valid migration file name:
// YYMMDD_hhmmss_<name>[.safe].(up|down).sql. Here <name> starts with a lowercase
// letter, followed by lowercase letters, digits, '_' or '-'.
const patternFileName = `^(?P<Year>\d{2})(?P<Month>\d{2})(?P<Day>\d{2})\_(?P<Hour>\d{2})(?P<Minute>\d{2})(?P<Second>\d{2})\_[a-z][a-z0-9\_\-]+(\.safe)?\.(up|down)\.sql$`

var (
	// regexpFileName is patternFileName compiled once at package initialization.
	regexpFileName = regexp.MustCompile(patternFileName)

	// groupsLenFileName is the length of a successful FindStringSubmatch result:
	// the whole match plus one entry per capture group.
	groupsLenFileName = len(regexpFileName.SubexpNames())

	// ErrFileNameIsNotValid is returned when a migration file name is rejected.
	ErrFileNameIsNotValid = errors.New("file name is not valid. File name must be eq pattern: YYMMDD_hhmmss_[a-z][a-z0-9\\_\\-]+(\\.safe)?\\.(up|down)\\.sql")
)
// ValidateFileName validates that a migration file name follows the required naming convention.
// The name must match the pattern: YYMMDD_hhmmss_name[.safe].(up|down).sql
// The timestamp must be valid and not in the future (accounting for timezone differences).
func ValidateFileName(name string) error {
if len(name) == 0 {
return ErrVersionIsNotValid
}
groups := regexpFileName.FindStringSubmatch(name)
if len(groups) < groupsLenFileName {
return ErrFileNameIsNotValid
}
yy, mm, dd, h, m, s := groups[1], groups[2], groups[3], groups[4], groups[5], groups[6]
dt, err := time.Parse(time.DateTime, "20"+yy+"-"+mm+"-"+dd+" "+h+":"+m+":"+s)
if err != nil {
return ErrFileNameIsNotValid
}
if dt.After(time.Now().Add(maxTZ)) {
return ErrFileNameIsNotValid
}
return nil
}
package validator
import "github.com/pkg/errors"
const (
	// maxIdentifierLen is the longest identifier, in bytes, that ValidateIdentifier accepts.
	maxIdentifierLen = 256
)

var (
	// ErrIdentifierIsNotValid is returned by ValidateIdentifier for rejected names.
	ErrIdentifierIsNotValid = errors.New("identifier is not valid")
)
// ValidateIdentifier returns ErrIdentifierIsNotValid in two cases: id is longer
// than maxIdentifierLen bytes, or it contains anything other than ASCII letters,
// digits and '_'. Otherwise it returns nil; note that an empty id is accepted.
func ValidateIdentifier(id string) error {
	if isValidIdentifier(id) {
		return nil
	}
	return ErrIdentifierIsNotValid
}
// isValidIdentifier reports whether name is at most maxIdentifierLen bytes long
// and consists only of characters accepted by isValidIdentifierChar.
// The empty string passes.
func isValidIdentifier(name string) bool {
	if len(name) > maxIdentifierLen {
		return false
	}
	// Checking bytes is enough here: every accepted character is ASCII. Any byte
	// of a multi-byte UTF-8 sequence is >= 0x80 and is rejected, just as its rune
	// would be.
	for i := 0; i < len(name); i++ {
		if !isValidIdentifierChar(rune(name[i])) {
			return false
		}
	}
	return true
}
// isValidIdentifierChar reports whether r is an ASCII letter, an ASCII digit or '_'.
func isValidIdentifierChar(r rune) bool {
	switch {
	case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
		return true
	default:
		return r == '_'
	}
}
package validator
import (
"strconv"
)
// TooShortError reports a value that is shorter than the required minimum.
type TooShortError struct {
	message string
}

// NewTooShortError returns a TooShortError whose message names minLen.
func NewTooShortError(minLen int) *TooShortError {
	msg := "This value should contain at least " + strconv.Itoa(minLen) + "."
	return &TooShortError{message: msg}
}

// Error implements the error interface.
func (e *TooShortError) Error() string { return e.message }
// TooLongError reports a value that is longer than the allowed maximum.
type TooLongError struct {
	message string
}

// NewTooLongError returns a TooLongError whose message names maxLen.
func NewTooLongError(maxLen int) *TooLongError {
	msg := "This value should contain at most " + strconv.Itoa(maxLen) + "."
	return &TooLongError{message: msg}
}

// Error implements the error interface.
func (e *TooLongError) Error() string { return e.message }
// ValidateStringLen checks a length (value) against the inclusive range
// [minLen, maxLen]. Below the range it returns a *TooShortError; above it, a
// *TooLongError. The lower bound is checked first.
func ValidateStringLen(minLen, maxLen, value int) error {
	switch {
	case value < minLen:
		return NewTooShortError(minLen)
	case value > maxLen:
		return NewTooLongError(maxLen)
	}
	return nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package validator
import (
"regexp"
"time"
"github.com/pkg/errors"
)
// maxTZ is the allowance added to the current time when rejecting timestamps
// from the future. It covers versions stamped in a time zone ahead of UTC.
const maxTZ = 8 * time.Hour

// patternVersion describes a valid migration version: YYMMDD_hhmmss_<name>.
// Here <name> starts with a lowercase letter, followed by lowercase letters,
// digits, '_' or '-'.
const patternVersion = `^(?P<Year>\d{2})(?P<Month>\d{2})(?P<Day>\d{2})\_(?P<Hour>\d{2})(?P<Minute>\d{2})(?P<Second>\d{2})\_[a-z][a-z0-9\_\-]+$`

var (
	// ErrVersionIsNotValid is returned when a version is rejected. Its message now
	// quotes the same name rule that patternVersion enforces. It previously
	// advertised `[a-z0-9_]+`, which neither required a leading letter nor
	// mentioned '-'.
	ErrVersionIsNotValid = errors.New("version is not valid. Version must be eq pattern: YYMMDD_hhmmss_[a-z][a-z0-9\\_\\-]+")

	// regexpVersion is patternVersion compiled once at package initialization.
	regexpVersion = regexp.MustCompile(patternVersion)

	// groupsLenVersion is the length of a successful FindStringSubmatch result.
	groupsLenVersion = len(regexpVersion.SubexpNames())
)
// ValidateVersion checks that version has the form YYMMDD_hhmmss_<name>
// (see patternVersion).
//
// The embedded timestamp must be a real date and time. It is parsed as UTC with
// the century fixed to 20xx, and it must not lie beyond now plus maxTZ. Every
// failure, including an empty version, yields ErrVersionIsNotValid.
func ValidateVersion(version string) error {
	if version == "" {
		return ErrVersionIsNotValid
	}

	parts := regexpVersion.FindStringSubmatch(version)
	if len(parts) < groupsLenVersion {
		return ErrVersionIsNotValid
	}

	stamp := "20" + parts[1] + "-" + parts[2] + "-" + parts[3] + " " +
		parts[4] + ":" + parts[5] + ":" + parts[6]
	dt, err := time.Parse(time.DateTime, stamp)
	if err != nil || dt.After(time.Now().Add(maxTZ)) {
		return ErrVersionIsNotValid
	}
	return nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package console
import (
"bufio"
"fmt"
"os"
"strings"
)
// Confirm prints the given text followed by " [y/n]: " and reads one line from
// standard input. It returns true only for the answers "y" or "yes" (trimmed,
// case-insensitive).
//
// Despite the parameter's name, the text is printed verbatim and is not treated
// as a format string, so a '%' in it is shown as-is. Use Confirmf to format the
// prompt.
func Confirm(format string) bool {
	return Confirmf("%s", format)
}

// Confirmf prints the prompt built from format and args (as fmt.Printf does),
// followed by " [y/n]: ", and reads one line from standard input. It returns
// true only when the answer, trimmed and lower-cased, is "y" or "yes". Any other
// answer counts as "no".
//
// If standard input cannot be read, the error is printed and the process exits
// with status 1. This includes reaching end of input before a newline.
//
// NOTE(review): a new bufio.Reader is created per call. With piped input it may
// buffer, and then drop, lines meant for later prompts.
func Confirmf(format string, args ...any) bool {
	reader := bufio.NewReader(os.Stdin)
	fmt.Printf(format+" [y/n]: ", args...)
	response, err := reader.ReadString('\n')
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	response = strings.ToLower(strings.TrimSpace(response))
	return response == "y" || response == "yes"
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package dsn
import "net/url"
// Host is one endpoint listed in a DSN.
type Host struct {
	Address string // endpoint as written, e.g. "host:port"
	Host    string // host component only
	Port    string // port component only; empty when none was given
}

// DSN is a parsed Data Source Name. Several hosts may be listed (cluster mode).
type DSN struct {
	Driver   string     // database driver: clickhouse, postgres, mysql, tarantool
	Hosts    []Host     // every listed endpoint, in order
	Database string     // database name
	Username string     // authentication username
	Password string     // authentication password
	Options  url.Values // query-string parameters
	Raw      string     // the DSN string as originally supplied
}

// Primary returns the first listed host, or the zero Host when none are listed.
func (d *DSN) Primary() Host {
	if len(d.Hosts) == 0 {
		return Host{}
	}
	return d.Hosts[0]
}

// PrimaryAddress returns the Address of the first listed host, or "" when none are listed.
func (d *DSN) PrimaryAddress() string {
	primary := d.Primary()
	return primary.Address
}

// IsCluster reports whether more than one host is listed.
func (d *DSN) IsCluster() bool {
	return len(d.Hosts) >= 2
}

// HasCredentials reports whether a username or a password is set.
func (d *DSN) HasCredentials() bool {
	return !(d.Username == "" && d.Password == "")
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package dsn
import (
"net/url"
"strings"
"github.com/pkg/errors"
)
var (
	// ErrInvalidDSN is returned by Parse, possibly wrapped with detail, when the
	// input is not a usable DSN.
	ErrInvalidDSN = errors.New("invalid DSN format")
)
// Parse splits raw into its DSN components. The expected shape is:
//
//	driver://[user[:pass]@]host[:port][,host[:port]...][/database][?options]
//
// Several comma-separated hosts describe a cluster. The parts are located as
// follows:
//   - credentials are the text before the first '@';
//   - options are the text after the first '?';
//   - the database is the text after the first '/' of what remains.
//
// Parse fails when raw is empty, lacks the "://" separator, has an unparsable
// query string, or names no host.
func Parse(raw string) (*DSN, error) {
	if raw == "" {
		return nil, ErrInvalidDSN
	}

	driver, afterScheme, ok := strings.Cut(raw, "://")
	if !ok {
		return nil, errors.Wrap(ErrInvalidDSN, "missing driver prefix")
	}

	// Without an '@' there are no credentials and the whole remainder is the location.
	var userInfo string
	location := afterScheme
	if before, after, hasAt := strings.Cut(afterScheme, "@"); hasAt {
		userInfo, location = before, after
	}
	username, password, _ := strings.Cut(userInfo, ":")

	hostAndDB, rawQuery, _ := strings.Cut(location, "?")
	options := make(url.Values)
	if rawQuery != "" {
		parsed, err := url.ParseQuery(rawQuery)
		if err != nil {
			return nil, errors.Wrap(err, "parsing options")
		}
		options = parsed
	}

	hostsStr, database, _ := strings.Cut(hostAndDB, "/")
	hosts := parseHosts(hostsStr)
	if len(hosts) == 0 {
		return nil, errors.Wrap(ErrInvalidDSN, "no hosts specified")
	}

	return &DSN{
		Driver:   driver,
		Hosts:    hosts,
		Database: database,
		Username: username,
		Password: password,
		Options:  options,
		Raw:      raw,
	}, nil
}
// parseHosts splits a comma-separated host list into Host values. Entries are
// trimmed and blank entries are skipped. An empty input yields nil.
//
// Bracketed IPv6 literals such as "[::1]:9000" are supported: the brackets are
// removed from Host, as net.SplitHostPort does. Previously they were cut at the
// first ':', which produced Host "[" for such addresses.
func parseHosts(hostsStr string) []Host {
	if hostsStr == "" {
		return nil
	}

	parts := strings.Split(hostsStr, ",")
	hosts := make([]Host, 0, len(parts))
	for _, part := range parts {
		addr := strings.TrimSpace(part)
		if addr == "" {
			continue
		}
		host, port := splitHostAndPort(addr)
		hosts = append(hosts, Host{
			Address: addr,
			Host:    host,
			Port:    port,
		})
	}
	return hosts
}

// splitHostAndPort splits "host[:port]" into its parts and understands
// bracketed IPv6 literals. The port is empty when absent. Anything else falls
// back to splitting at the first ':'.
func splitHostAndPort(addr string) (host, port string) {
	if strings.HasPrefix(addr, "[") {
		if end := strings.IndexByte(addr, ']'); end > 0 {
			rest := addr[end+1:]
			// After the closing bracket, only nothing or a ":port" suffix is meaningful.
			if rest == "" || rest[0] == ':' {
				return addr[1:end], strings.TrimPrefix(rest, ":")
			}
		}
	}
	host, port, _ = strings.Cut(addr, ":")
	return host, port
}
// MustParse is like Parse but panics with Parse's error instead of returning it.
// It is meant for tests and package initialization, where a bad DSN is a programming error.
func MustParse(raw string) *DSN {
	parsed, err := Parse(raw)
	if err != nil {
		panic(err)
	}
	return parsed
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package iohelp
import (
"io"
"os"
"github.com/pkg/errors"
)
// File provides file-system helpers: existence checks, reading and creation.
// It holds no state.
type File struct{}

// StdFile is the shared File instance for general use.
var StdFile = NewFile()

// NewFile returns a new File.
func NewFile() *File {
	return &File{}
}

// Exists reports whether fileName exists. A missing file yields (false, nil).
// Any other Stat failure (e.g. permission denied) is returned as an error.
// A directory at fileName also counts as existing.
func (f *File) Exists(fileName string) (bool, error) {
	_, err := os.Stat(fileName)
	switch {
	case err == nil:
		return true, nil
	case os.IsNotExist(err):
		return false, nil
	default:
		return false, err
	}
}

// Open opens filename for reading. On failure the returned ReadCloser is a true
// nil interface, so callers may safely compare it against nil.
func (f *File) Open(filename string) (io.ReadCloser, error) {
	file, err := os.Open(filename)
	if err != nil {
		// Returning os.Open's results directly would box a nil *os.File into a
		// non-nil io.ReadCloser.
		return nil, err
	}
	return file, nil
}

// ReadAll returns the entire contents of filename.
func (f *File) ReadAll(filename string) ([]byte, error) {
	rc, err := f.Open(filename)
	if err != nil {
		return nil, err
	}
	defer rc.Close() // read-only; a close error cannot lose data
	return io.ReadAll(rc)
}
// Create creates filename, truncating it if it already exists, and closes it
// immediately. A failure in either step is wrapped with the file name.
func (f *File) Create(filename string) error {
	created, err := os.Create(filename)
	if err != nil {
		return errors.Wrapf(err, "creating file %s", filename)
	}
	if err := created.Close(); err != nil {
		return errors.Wrapf(err, "creating file %s", filename)
	}
	return nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package plural
// Singular and plural phrases used in user-facing migration messages.
const (
	migration  = "migration"
	migrations = "migrations"

	migrationWas   = "migration was"
	migrationsWere = "migrations were"

	migrationHas   = "migration has"
	migrationsHave = "migrations have"
)
// NumberPlural returns one when c is exactly 1 and many for every other count.
// This includes zero ("0 migrations were applied"); previously zero, and
// negative counts, got the singular form.
func NumberPlural(c int, one, many string) string {
	if c == 1 {
		return one
	}
	return many
}
// Migration returns "migration" for count 1 or "migrations" for count greater than 1.
func Migration(c int) string {
return NumberPlural(c, migration, migrations)
}
// MigrationWas returns "migration was" for count 1 or "migrations were" for count greater than 1.
func MigrationWas(c int) string {
return NumberPlural(c, migrationWas, migrationsWere)
}
// MigrationHas returns "migration has" for count 1 or "migrations have" for count greater than 1.
func MigrationHas(c int) string {
return NumberPlural(c, migrationHas, migrationsHave)
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package sqlio
import (
"bufio"
"bytes"
"fmt"
"io"
"strings"
)
const (
	// maxMigrationSize is the largest size the scanner buffer may grow to
	// (10 MiB). It therefore bounds the size of a single SQL statement.
	maxMigrationSize = 10 * 1 << 20
)

var (
	// multiStmtDelimiter terminates a SQL statement.
	multiStmtDelimiter = []byte(";")
	// psqlPLFuncDelimiter opens and closes a PostgreSQL dollar-quoted body (e.g. a
	// PL/pgSQL function); semicolons inside it do not end the statement. Only the
	// untagged "$$" form is recognized.
	psqlPLFuncDelimiter = []byte("$$")
)

// StartBufSize is the default starting size of the buffer used to scan and parse multi-statement migrations.
var StartBufSize = 4096

// Scanner scans SQL migration files and splits them into individual statements.
// It splits on ';' but keeps PostgreSQL $$ ... $$ bodies intact.
type Scanner struct {
	scanner *bufio.Scanner
	sql     string
	err     error
	done    bool      // set once the underlying scanner is exhausted
	closer  io.Closer // the reader given to NewScanner, if it implements io.Closer
}

// NewScanner creates a Scanner that reads statements from r. If r also
// implements io.Closer, Close closes it.
func NewScanner(r io.Reader) *Scanner {
	s := bufio.NewScanner(r)
	s.Buffer(make([]byte, 0, StartBufSize), maxMigrationSize)
	s.Split(splitWithDelimiter())

	var closer io.Closer
	if rc, ok := r.(io.Closer); ok {
		closer = rc
	}

	return &Scanner{
		scanner: s,
		closer:  closer,
	}
}

// SQL returns the statement found by the last successful Scan. Surrounding
// whitespace and leading/trailing ';' characters are removed.
func (s *Scanner) SQL() string {
	return s.sql
}

// Err returns the first error encountered while scanning, if any. A normal end
// of input is not an error.
func (s *Scanner) Err() error {
	return s.err
}

// Scan advances to the next non-empty statement. It returns false when the
// input is exhausted or an error occurred (see Err), and on every call after that.
func (s *Scanner) Scan() bool {
	if s.done {
		return false
	}
	for s.scanner.Scan() {
		s.sql = strings.TrimSpace(s.scanner.Text())
		s.sql = strings.Trim(s.sql, ";")
		if s.sql == "" {
			continue
		}
		return true
	}
	s.err = s.scanner.Err()
	s.done = true
	return false
}

// Close closes the underlying reader if it implements io.Closer.
// It is safe to call Close multiple times.
// Returns nil if the reader doesn't implement io.Closer.
func (s *Scanner) Close() error {
	if s.closer != nil {
		err := s.closer.Close()
		s.closer = nil // prevent a double close
		return err
	}
	return nil
}

// splitWithDelimiter returns a bufio.SplitFunc that emits one SQL statement per
// token, including its terminating ';'.
//
// A ';' that appears before any "$$" ends the statement. Once a "$$" has been
// opened, the statement instead ends at the first ';' after the matching
// closing "$$".
//
// At end of input:
//   - a plain statement without a trailing ';' is returned as-is;
//   - an unclosed "$$" body is an error;
//   - a closed "$$" body with no ';' after it is an error.
//
// The last case used to be dropped silently when no ';' was buffered at all.
func splitWithDelimiter() func(d []byte, atEOF bool) (int, []byte, error) {
	return func(d []byte, atEOF bool) (int, []byte, error) {
		// SplitFunc inspired by bufio.ScanLines() implementation.
		if atEOF && len(d) == 0 {
			return 0, nil, nil
		}

		openPi, pLen := bytes.Index(d, psqlPLFuncDelimiter), len(psqlPLFuncDelimiter)
		delI, delLen := bytes.Index(d, multiStmtDelimiter), len(multiStmtDelimiter)

		switch {
		case delI >= 0 && (openPi < 0 || delI < openPi):
			// The terminator comes before any dollar-quoted body.
			end := delI + delLen
			return end, d[:end], nil
		case openPi >= 0:
			// A dollar-quoted body opens first. This is also the case when no ';' is buffered yet.
			bodyStart := openPi + pLen
			closePi := bytes.Index(d[bodyStart:], psqlPLFuncDelimiter)
			if closePi < 0 {
				if atEOF {
					return 0, nil, fmt.Errorf("closed tag %s not found", psqlPLFuncDelimiter)
				}
				return 0, nil, nil // need more data
			}
			closeAt := bodyStart + closePi
			semi := bytes.Index(d[closeAt:], multiStmtDelimiter)
			if semi < 0 {
				if atEOF {
					return 0, nil, fmt.Errorf("closed tag %s not found", multiStmtDelimiter)
				}
				return 0, nil, nil // need more data
			}
			end := closeAt + semi + delLen
			return end, d[:end], nil
		default:
			// No terminator and no dollar quote buffered.
			if atEOF {
				return len(d), d, nil
			}
			return 0, nil, nil
		}
	}
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package timex
import (
"time"
)
// Time abstracts the source of the current time, so callers (tests in
// particular) can substitute a fixed or simulated clock.
type Time interface {
	Now() time.Time
}

// StdTime is the default Time; it reports the wall clock via time.Now.
var StdTime = New(time.Now)

// stdTime is the Time implementation returned by New; it delegates to nowFunc.
type stdTime struct {
	nowFunc func() time.Time
}

// New returns a Time whose Now method calls nowFunc.
//
//nolint:ireturn,nolintlint // its ok
func New(nowFunc func() time.Time) Time {
	clock := new(stdTime)
	clock.nowFunc = nowFunc
	return clock
}

// Now returns whatever the configured function reports.
func (s *stdTime) Now() time.Time {
	now := s.nowFunc
	return now()
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package urfavecli
import (
"context"
"github.com/raoptimus/db-migrator.go/internal/application/handler"
"github.com/urfave/cli/v3"
)
// Adapt turns a handler.Handler into a urfave/cli action. On each invocation,
// the command's arguments are wrapped in a handler.Command bound to the
// invocation context, and the result is passed to h.Handle.
func Adapt(h handler.Handler) cli.ActionFunc {
	action := func(ctx context.Context, cmd *cli.Command) error {
		internal := &handler.Command{Args: cmd.Args()}
		return h.Handle(internal.WithContext(ctx))
	}
	return action
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package connection
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/helper/dsn"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/sqlex"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/sqlex/tarantool"
)
// durationBeforeNextConnAttempt is the pause Try makes between connection attempts.
const durationBeforeNextConnAttempt = time.Second

// ErrTransactionAlreadyOpened is returned by Transaction when the context already carries a transaction.
var ErrTransactionAlreadyOpened = errors.New("transaction already opened")
// Connection wraps a database handle together with the driver and DSN it was
// opened with. The ping field records a successful Ping so that later Ping
// calls return immediately.
type Connection struct {
	driver Driver
	dsn    string
	db     SQLDB
	ping   bool
}
// New parses dsnStr and opens a connection with the driver named by its scheme
// (clickhouse, postgres, mysql or tarantool). Any other scheme is an error. New
// does not call Ping; use Ping or Try to verify that the server is reachable.
func New(dsnStr string) (*Connection, error) {
	parsed, err := dsn.Parse(dsnStr)
	if err != nil {
		return nil, errors.WithMessage(err, "parsing DSN")
	}

	var open func(string) (*Connection, error)
	switch Driver(parsed.Driver) {
	case DriverClickhouse:
		open = clickhouse
	case DriverPostgres:
		open = postgres
	case DriverMySQL:
		open = mysql
	case DriverTarantool:
		open = tarantoolConn
	default:
		return nil, fmt.Errorf("driver \"%s\" doesn't support", parsed.Driver)
	}
	return open(dsnStr)
}
// DSN returns the connection string exactly as passed to New, credentials included.
func (c *Connection) DSN() string { return c.dsn }

// Driver returns the driver the connection was opened with.
func (c *Connection) Driver() Driver { return c.driver }
// Ping verifies that the database is reachable. After the first success the
// result is remembered, and later calls return nil without contacting the server.
//
// The error message identifies the target with the password masked and query
// options omitted, so it is safe to log. Previously it embedded the raw DSN,
// password included.
func (c *Connection) Ping() error {
	if c.ping {
		return nil
	}
	if err := c.db.Ping(); err != nil {
		return errors.Wrapf(err, "ping %v connection: %v", c.Driver(), redactedDSN(c.dsn))
	}
	c.ping = true
	return nil
}

// redactedDSN renders raw without its password or query options, for error
// messages. If raw cannot be parsed, a placeholder is returned rather than
// risking echoing credentials.
func redactedDSN(raw string) string {
	parsed, err := dsn.Parse(raw)
	if err != nil {
		return "<unparsable dsn>"
	}
	out := parsed.Driver + "://"
	if parsed.HasCredentials() {
		out += parsed.Username
		if parsed.Password != "" {
			out += ":***"
		}
		out += "@"
	}
	for i, h := range parsed.Hosts {
		if i > 0 {
			out += ","
		}
		out += h.Address
	}
	if parsed.Database != "" {
		out += "/" + parsed.Database
	}
	return out
}
// QueryContext runs a row-returning query (typically a SELECT) on the
// underlying database handle, with args bound to the query's placeholders.
//
//nolint:ireturn,nolintlint // its ok
func (c *Connection) QueryContext(ctx context.Context, query string, args ...any) (sqlex.Rows, error) {
	rows, err := c.db.QueryContext(ctx, query, args...)
	return rows, err
}
// ExecContext runs a statement that returns no rows, with args bound to its
// placeholders. If ctx carries a transaction (see ContextWithTx), the statement
// runs inside it; otherwise it runs on the underlying database handle.
//
// NOTE(review): ClickHouse may need a prepared statement inside transactions
// (tx.PrepareContext + stmt.ExecContext); this is not done here.
//
//nolint:ireturn,nolintlint // its ok
func (c *Connection) ExecContext(ctx context.Context, query string, args ...any) (sqlex.Result, error) {
	if tx, err := TxFromContext(ctx); err == nil {
		return tx.ExecContext(ctx, query, args...)
	}
	return c.db.ExecContext(ctx, query, args...)
}
// Transaction runs txFn inside a new database transaction. txFn receives a
// context carrying the transaction, so ExecContext calls made with that context
// join it. Nesting is rejected with ErrTransactionAlreadyOpened.
//
// The outcome depends on txFn:
//   - if txFn fails, the transaction is rolled back and txFn's error is returned,
//     annotated with the rollback error if the rollback also fails;
//   - otherwise the transaction is committed.
func (c *Connection) Transaction(ctx context.Context, txFn func(ctx context.Context) error) error {
	if _, err := TxFromContext(ctx); err == nil {
		return ErrTransactionAlreadyOpened
	}

	tx, err := c.db.BeginTx(ctx, nil)
	if err != nil {
		return errors.Wrap(err, "begin transaction")
	}

	fnErr := txFn(ContextWithTx(ctx, tx))
	if fnErr != nil {
		rbErr := tx.Rollback()
		if rbErr == nil {
			return fnErr
		}
		return errors.Wrapf(fnErr, "rollback failed: %v", rbErr)
	}

	if commitErr := tx.Commit(); commitErr != nil {
		_ = tx.Rollback() // best effort; the commit error is the one reported
		return errors.Wrap(commitErr, "commit transaction")
	}
	return nil
}
// Close closes the underlying database handle and returns its error.
func (c *Connection) Close() error {
	err := c.db.Close()
	return err
}
// Try opens a connection for dsn and pings it. It makes up to maxAttempts
// attempts (at least one), waiting durationBeforeNextConnAttempt between them.
// It returns the first connection that answers the ping, or the last error.
//
// A connection whose ping fails is closed before the next attempt. Previously
// each failed attempt leaked its database handle.
func Try(dsn string, maxAttempts int) (*Connection, error) {
	if maxAttempts < 1 {
		maxAttempts = 1
	}

	var lastErr error
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		conn, err := New(dsn)
		if err == nil {
			if err = conn.Ping(); err == nil {
				return conn, nil
			}
			// The ping error is what the caller needs; a close failure adds nothing useful.
			_ = conn.Close()
		}
		lastErr = err

		if attempt < maxAttempts {
			time.Sleep(durationBeforeNextConnAttempt)
		}
	}
	return nil, lastErr
}
// clickhouse opens a ClickHouse connection handle for dsn. Like the other driver
// constructors, it names the driver via its Driver constant and wraps open
// errors with context.
func clickhouse(dsn string) (*Connection, error) {
	db, err := sql.Open(DriverClickhouse.String(), dsn)
	if err != nil {
		return nil, errors.Wrap(err, "open clickhouse connection")
	}
	return &Connection{
		driver: DriverClickhouse,
		dsn:    dsn,
		db:     &sqlex.DB{DB: db},
	}, nil
}
// postgres opens a PostgreSQL connection handle for dsn, which is passed to
// sql.Open unchanged.
func postgres(dsn string) (*Connection, error) {
	db, err := sql.Open(DriverPostgres.String(), dsn)
	if err != nil {
		return nil, errors.Wrap(err, "open postgres connection")
	}
	conn := &Connection{driver: DriverPostgres, dsn: dsn}
	conn.db = &sqlex.DB{DB: db}
	return conn, nil
}
// mysql opens a MySQL connection handle for dsn. The DSN must start with the
// "mysql://" scheme, which New guarantees. The scheme is stripped before the
// DSN is handed to sql.Open; the unmodified dsn is kept on the Connection.
// A DSN without the scheme now yields an error instead of a slice panic or a
// silently mangled DSN.
func mysql(dsn string) (*Connection, error) {
	prefix := DriverMySQL.String() + "://"
	if len(dsn) < len(prefix) || dsn[:len(prefix)] != prefix {
		return nil, errors.Errorf("open mysql connection: DSN must start with %q", prefix)
	}

	db, err := sql.Open(DriverMySQL.String(), dsn[len(prefix):])
	if err != nil {
		return nil, errors.Wrap(err, "open mysql connection")
	}
	return &Connection{
		driver: DriverMySQL,
		dsn:    dsn,
		db:     &sqlex.DB{DB: db},
	}, nil
}
// tarantoolConn opens a Tarantool connection for dsn through the project's
// tarantool adapter. Like the other driver constructors, it wraps open errors
// with context.
func tarantoolConn(dsn string) (*Connection, error) {
	db, err := tarantool.Open(dsn)
	if err != nil {
		return nil, errors.Wrap(err, "open tarantool connection")
	}
	return &Connection{
		driver: DriverTarantool,
		dsn:    dsn,
		db:     db,
	}, nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package connection
import (
"context"
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/sqlex"
)
// ErrNoTransaction is returned by TxFromContext when the context carries no transaction.
var ErrNoTransaction = errors.New("no transaction")
// contextKey is an unexported key type, so context values stored by this
// package cannot collide with keys defined elsewhere.
type contextKey int

// contextKeyTX is the key under which ContextWithTx stores the transaction.
const contextKeyTX contextKey = 0
// ContextWithTx derives a context from parent that carries v. TxFromContext
// retrieves it, and Connection.ExecContext uses it.
func ContextWithTx(parent context.Context, v sqlex.Tx) context.Context {
	ctx := context.WithValue(parent, contextKeyTX, v)
	return ctx
}
// TxFromContext returns the transaction stored by ContextWithTx, or
// ErrNoTransaction when parent carries none.
//
//nolint:ireturn,nolintlint // its ok
func TxFromContext(parent context.Context) (sqlex.Tx, error) {
	if tx, ok := parent.Value(contextKeyTX).(sqlex.Tx); ok {
		return tx, nil
	}
	return nil, ErrNoTransaction
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package connection
// Driver names a supported database driver; its value is also the DSN scheme.
type Driver string

// Supported drivers.
const (
	DriverClickhouse Driver = "clickhouse" // ClickHouse
	DriverMySQL      Driver = "mysql"      // MySQL
	DriverPostgres   Driver = "postgres"   // PostgreSQL
	DriverTarantool  Driver = "tarantool"  // Tarantool
)

// String returns the driver name.
func (d Driver) String() string {
	name := string(d)
	return name
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package entity
import (
"sort"
"time"
)
// Migration is one row of the migration history table. It holds the applied
// version and the Unix time, in seconds, at which it was applied.
type Migration struct {
	Version   string `db:"version"`
	ApplyTime int64  `db:"apply_time"`
	// Columns not mapped yet:
	// BodySQL string `db:"body_sql"`
	// ExecutedSQL string `db:"executed_sql"`
	// Release string `db:"release"`
}

// Migrations is a list of history records. It implements sort.Interface,
// ordering records by Version.
type Migrations []Migration

// Len implements sort.Interface.
func (s Migrations) Len() int { return len(s) }

// Less implements sort.Interface: records compare by Version as plain strings (byte-wise).
func (s Migrations) Less(i, j int) bool {
	a, b := s[i].Version, s[j].Version
	return a < b
}

// Swap implements sort.Interface.
func (s Migrations) Swap(i, j int) {
	tmp := s[i]
	s[i] = s[j]
	s[j] = tmp
}

// SortByVersion sorts s in place by ascending Version. The sort is not guaranteed to be stable.
func (s Migrations) SortByVersion() {
	sort.Sort(s)
}

// ApplyTimeFormat renders ApplyTime in the local time zone as "YYYY-MM-DD HH:MM:SS".
func (s Migration) ApplyTimeFormat() string {
	return time.Unix(s.ApplyTime, 0).Format(time.DateTime)
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package repository
import (
"context"
"fmt"
"time"
"github.com/ClickHouse/clickhouse-go/v2"
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/dal/entity"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/sqlex"
)
// Clickhouse is the ClickHouse migration repository. It tracks applied
// migrations in a history table and executes migration SQL. Cluster and
// replication behaviour are driven by its Options.
type Clickhouse struct {
	conn    Connection
	options *Options
}
// NewClickhouse returns a Clickhouse repository that uses conn and options.
// options is stored as given, not copied.
func NewClickhouse(conn Connection, options *Options) *Clickhouse {
	repo := &Clickhouse{conn: conn}
	repo.options = options
	return repo
}
// Migrations returns up to limit applied migrations, newest first (by
// apply_time, then version). Rows with is_deleted set are skipped. Database
// errors pass through dbError and are wrapped with "get migrations".
func (ch *Clickhouse) Migrations(ctx context.Context, limit int) (entity.Migrations, error) {
	q := `
SELECT version, apply_time
FROM ` + ch.dTableNameWithSchema() + `
WHERE is_deleted = 0
ORDER BY apply_time DESC, version DESC
LIMIT ?
`
	rows, err := ch.conn.QueryContext(ctx, q, limit)
	if err != nil {
		return nil, errors.Wrap(ch.dbError(err, q), "get migrations")
	}
	defer rows.Close()

	var migrations entity.Migrations
	for rows.Next() {
		var m entity.Migration
		if err := rows.Scan(&m.Version, &m.ApplyTime); err != nil {
			return nil, errors.Wrap(ch.dbError(err, q), "get migrations")
		}
		migrations = append(migrations, m)
	}
	if err := rows.Err(); err != nil {
		return nil, errors.Wrap(ch.dbError(err, q), "get migrations")
	}
	return migrations, nil
}
// HasMigrationHistoryTable reports whether the table named by dTableName exists
// in the current database. It checks by looking the name up in system.columns.
func (ch *Clickhouse) HasMigrationHistoryTable(ctx context.Context) (exists bool, err error) {
	q := `
SELECT database, table
FROM system.columns
WHERE table = ? AND database = currentDatabase()
`
	var rows sqlex.Rows
	if rows, err = ch.conn.QueryContext(ctx, q, ch.dTableName()); err != nil {
		return false, errors.Wrap(ch.dbError(err, q), "get table schema")
	}
	defer rows.Close()

	for rows.Next() {
		var database, table string
		if err := rows.Scan(&database, &table); err != nil {
			return false, errors.Wrap(ch.dbError(err, q), "get table schema")
		}
		// TODO: scan the columns into a table schema.
		if table != ch.dTableName() {
			continue
		}
		return true, nil
	}

	if err := rows.Err(); err != nil {
		return false, errors.Wrap(ch.dbError(err, q), "get table schema")
	}
	return false, nil
}
// InsertMigration records version as applied, via insertMigration with the deleted flag unset.
func (ch *Clickhouse) InsertMigration(ctx context.Context, version string) error {
	const markDeleted = false
	return ch.insertMigration(ctx, version, markDeleted)
}

// RemoveMigration marks version as reverted, via insertMigration with the
// deleted flag set.
//
// NOTE(review): presumably this writes a row with is_deleted = 1 (Migrations
// filters on is_deleted = 0) rather than deleting anything; confirm in
// insertMigration.
func (ch *Clickhouse) RemoveMigration(ctx context.Context, version string) error {
	const markDeleted = true
	return ch.insertMigration(ctx, version, markDeleted)
}
// ExecQuery runs query with the placeholder values in args and discards any
// result. ClickHouse exceptions are converted to *DBError by dbError.
func (ch *Clickhouse) ExecQuery(ctx context.Context, query string, args ...any) error {
	if _, err := ch.conn.ExecContext(ctx, query, args...); err != nil {
		return ch.dbError(err, query)
	}
	return nil
}
// ExecQueryTransaction runs txFn inside a transaction opened by the underlying
// connection and returns the error reported by the connection's Transaction.
// TODO: the name says ExecQuery, yet no query is accepted; consider renaming.
func (ch *Clickhouse) ExecQueryTransaction(ctx context.Context, txFn func(ctx context.Context) error) error {
	conn := ch.conn
	return conn.Transaction(ctx, txFn)
}
// CreateMigrationHistoryTable creates the migration history table, sorted and
// keyed by version with apply_time as the ReplacingMergeTree version column.
//
// Three layouts are produced:
//   - cluster mode (isUsedCluster): the local table is created ON CLUSTER with a
//     ReplicatedReplacingMergeTree engine, followed by a Distributed table
//     named d_<table> built on top of it;
//   - Replicated without cluster mode: a ReplicatedReplacingMergeTree table;
//   - otherwise: a plain ReplacingMergeTree table.
func (ch *Clickhouse) CreateMigrationHistoryTable(ctx context.Context) error {
	replicatedEngine := "ReplicatedReplacingMergeTree('/clickhouse/tables/{shard}/" +
		ch.options.ClusterName + "_" + ch.options.TableName + "', '{replica}', apply_time)"

	engine := "ReplacingMergeTree(apply_time)"
	onCluster := ""
	distributedQuery := ""
	if ch.isUsedCluster() {
		engine = replicatedEngine
		onCluster = "ON CLUSTER " + ch.options.ClusterName
		distributedQuery = fmt.Sprintf(`
CREATE TABLE %[2]s.d_%[3]s ON CLUSTER %[1]s AS %[2]s.%[3]s
ENGINE = Distributed('%[1]s', '%[2]s', %[3]s, cityHash64(toString(version)))
`,
			ch.options.ClusterName,
			ch.options.SchemaName,
			ch.options.TableName,
		)
	} else if ch.options.Replicated {
		engine = replicatedEngine
	}

	createQuery := fmt.Sprintf(
		`
CREATE TABLE %s %s (
version String,
date Date DEFAULT toDate(apply_time),
apply_time UInt32,
is_deleted UInt8
) ENGINE = %s
PRIMARY KEY (version)
PARTITION BY (toYYYYMM(date))
ORDER BY (version)
SETTINGS index_granularity=8192
`,
		ch.TableNameWithSchema(),
		onCluster,
		engine,
	)

	// The Distributed table copies its structure from the local table
	// ("AS db.table"), so the local table has to be created first.
	for _, stmt := range []string{createQuery, distributedQuery} {
		if stmt == "" {
			continue
		}
		if _, err := ch.conn.ExecContext(ctx, stmt); err != nil {
			return errors.Wrap(ch.dbError(err, stmt), "create migration history table")
		}
	}
	return nil
}
// DropMigrationHistoryTable drops the local history table and, in cluster mode,
// the distributed d_ table as well, stopping at the first failure.
func (ch *Clickhouse) DropMigrationHistoryTable(ctx context.Context) error {
	tables := []string{ch.TableNameWithSchema()}
	if ch.isUsedCluster() {
		tables = append(tables, ch.dTableNameWithSchema())
	}
	for _, name := range tables {
		if err := ch.dropTable(ctx, name); err != nil {
			return err
		}
	}
	return nil
}
// MigrationsCount returns how many history rows are not flagged is_deleted.
// On failure it returns 0 together with the error.
func (ch *Clickhouse) MigrationsCount(ctx context.Context) (int, error) {
	var count int
	err := ch.QueryScalar(ctx,
		"SELECT count(*) FROM "+ch.dTableNameWithSchema()+" WHERE is_deleted = 0",
		&count,
	)
	if err != nil {
		return 0, err
	}
	return count, nil
}
// QueryScalar executes query and scans the first column of the first row into ptr.
// ptr must be a pointer to a scalar type accepted by checkArgIsPtrAndScalar
// (int, string, bool, time.Time, ...). If the query yields no rows, ptr is left
// untouched and nil is returned.
//
// Every driver failure, including the one from starting the query, is passed
// through dbError so callers always get a *DBError for ClickHouse exceptions.
func (ch *Clickhouse) QueryScalar(ctx context.Context, query string, ptr any) error {
	if err := checkArgIsPtrAndScalar(ptr); err != nil {
		return err
	}
	rows, err := ch.conn.QueryContext(ctx, query)
	if err != nil {
		return ch.dbError(err, query)
	}
	defer rows.Close()
	if rows.Next() {
		if err := rows.Scan(ptr); err != nil {
			return ch.dbError(err, query)
		}
	}
	if err := rows.Err(); err != nil {
		return ch.dbError(err, query)
	}
	return nil
}
// ExistsMigration reports whether version has a history row that is not
// flagged is_deleted.
//
// NOTE(review): the read is done without FINAL, so it relies on optimizeTable
// (run after every insert) having merged older rows with their tombstones — confirm.
//
// All driver failures, including the one from starting the query, go through
// dbError, consistent with the other read paths of this repository.
func (ch *Clickhouse) ExistsMigration(ctx context.Context, version string) (bool, error) {
	q := "SELECT 1 FROM " + ch.dTableNameWithSchema() + " WHERE version = ? AND is_deleted = 0"
	rows, err := ch.conn.QueryContext(ctx, q, version)
	if err != nil {
		return false, ch.dbError(err, q)
	}
	defer rows.Close()
	var exists int
	if rows.Next() {
		if err := rows.Scan(&exists); err != nil {
			return false, ch.dbError(err, q)
		}
	}
	if err := rows.Err(); err != nil {
		return false, ch.dbError(err, q)
	}
	return exists == 1, nil
}
// TableNameWithSchema returns the local (non-distributed) history table name
// qualified with the database, as "database.table".
func (ch *Clickhouse) TableNameWithSchema() string {
	schema, table := ch.options.SchemaName, ch.options.TableName
	return schema + "." + table
}
// dropTable issues DROP TABLE for tableName. In cluster mode the statement is
// run ON CLUSTER with NO DELAY.
func (ch *Clickhouse) dropTable(ctx context.Context, tableName string) error {
	suffix := ""
	if ch.isUsedCluster() {
		suffix = " ON CLUSTER " + ch.options.ClusterName + " NO DELAY"
	}
	query := "DROP TABLE " + tableName + suffix
	if _, err := ch.conn.ExecContext(ctx, query); err != nil {
		return errors.Wrap(ch.dbError(err, query), "drop migration history table")
	}
	return nil
}
// dTableName returns the table the repository reads and writes: "d_<table>"
// (the Distributed table) in cluster mode, the plain table name otherwise.
func (ch *Clickhouse) dTableName() string {
	name := ch.options.TableName
	if !ch.isUsedCluster() {
		return name
	}
	return "d_" + name
}
// dTableNameWithSchema returns dTableName qualified with the database name,
// e.g. "database.d_table" in cluster mode.
func (ch *Clickhouse) dTableNameWithSchema() string {
	prefix := ch.options.SchemaName + "."
	return prefix + ch.dTableName()
}
// isUsedCluster reports whether cluster mode (ON CLUSTER statements plus a
// Distributed table) is active. That is the case when a cluster name is set
// and Replicated is off.
func (ch *Clickhouse) isUsedCluster() bool {
	if ch.options.Replicated {
		return false
	}
	return ch.options.ClusterName != ""
}
// insertMigration appends a history row for version, stamped with the current
// Unix time, with is_deleted set from isDeleted. The insert runs inside
// ExecQueryTransaction, and afterwards the local table is optimized so the
// ReplacingMergeTree collapses rows that share a version.
func (ch *Clickhouse) insertMigration(ctx context.Context, version string, isDeleted bool) error {
	query := `
INSERT INTO ` + ch.dTableNameWithSchema() + ` (version, apply_time, is_deleted)
VALUES(?, ?, ?)
`
	deletedFlag := 0
	if isDeleted {
		deletedFlag = 1
	}
	//nolint:gosec // overflow ok
	appliedAt := uint32(time.Now().Unix())

	err := ch.ExecQueryTransaction(ctx, func(txCtx context.Context) error {
		return ch.ExecQuery(txCtx, query, version, appliedAt, deletedFlag)
	})
	if err != nil {
		return errors.Wrap(ch.dbError(err, query), "insert migration")
	}
	return ch.optimizeTable(ctx)
}
// optimizeTable runs OPTIMIZE ... FINAL on the local history table, forcing an
// unscheduled merge so that the ReplacingMergeTree keeps only the latest row
// (by apply_time) per version right away.
//
// The table name is schema-qualified, like every other statement issued by this
// repository, so the statement does not depend on the session's current
// database. In cluster mode (isUsedCluster) it is propagated ON CLUSTER.
func (ch *Clickhouse) optimizeTable(ctx context.Context) error {
	table := ch.TableNameWithSchema()
	q := fmt.Sprintf("OPTIMIZE TABLE %s FINAL", table)
	if ch.isUsedCluster() {
		q = fmt.Sprintf("OPTIMIZE TABLE %s ON CLUSTER %s FINAL", table, ch.options.ClusterName)
	}
	if _, err := ch.conn.ExecContext(ctx, q); err != nil {
		return errors.Wrap(ch.dbError(err, q), "optimize table")
	}
	return nil
}
// InsertMigrationWithApplyTime records version as applied at the given Unix
// timestamp rather than now (is_deleted = 0), then optimizes the local table.
// applyTime is converted to uint32 to match the apply_time UInt32 column;
// values outside that range wrap.
func (ch *Clickhouse) InsertMigrationWithApplyTime(ctx context.Context, version string, applyTime int64) error {
	query := `
INSERT INTO ` + ch.dTableNameWithSchema() + ` (version, apply_time, is_deleted)
VALUES(?, ?, ?)
`
	//nolint:gosec // overflow ok
	ts := uint32(applyTime)
	insert := func(txCtx context.Context) error {
		return ch.ExecQuery(txCtx, query, version, ts, 0)
	}
	if err := ch.ExecQueryTransaction(ctx, insert); err != nil {
		return errors.Wrap(ch.dbError(err, query), "insert migration")
	}
	return ch.optimizeTable(ctx)
}
// MigrationsByMaxApplyTime returns every non-deleted migration whose apply_time
// equals the latest apply_time among non-deleted rows, ordered by version
// descending.
func (ch *Clickhouse) MigrationsByMaxApplyTime(ctx context.Context) (entity.Migrations, error) {
	const op = "get migrations by max apply time"
	table := ch.dTableNameWithSchema()
	query := `
SELECT version, apply_time
FROM ` + table + `
WHERE is_deleted = 0 AND apply_time = (
SELECT MAX(apply_time) FROM ` + table + ` WHERE is_deleted = 0
)
ORDER BY version DESC
`
	rows, err := ch.conn.QueryContext(ctx, query)
	if err != nil {
		return nil, errors.Wrap(ch.dbError(err, query), op)
	}
	defer rows.Close()

	var latest entity.Migrations
	for rows.Next() {
		var (
			ver       string
			appliedAt int64
		)
		if err = rows.Scan(&ver, &appliedAt); err != nil {
			return nil, errors.Wrap(ch.dbError(err, query), op)
		}
		latest = append(latest, entity.Migration{Version: ver, ApplyTime: appliedAt})
	}
	if err = rows.Err(); err != nil {
		return nil, errors.Wrap(ch.dbError(err, query), op)
	}
	return latest, nil
}
// dbError converts a *clickhouse.Exception found in err's chain into a *DBError
// carrying the exception code (as decimal text), message, server stack trace
// and the query q. Any other error, including nil, is returned unchanged.
//
// Errors that already contain a *DBError (for example one produced by ExecQuery
// inside a transaction callback) are returned as-is rather than nested again.
func (ch *Clickhouse) dbError(err error, q string) error {
	var converted *DBError
	if errors.As(err, &converted) {
		return err
	}
	var clickEx *clickhouse.Exception
	if !errors.As(err, &clickEx) {
		return err
	}
	return errors.WithStack(&DBError{
		// Code is numeric; string(code) would yield a single rune, not "60".
		Code:          fmt.Sprint(clickEx.Code),
		Message:       clickEx.Message,
		Details:       clickEx.StackTrace,
		InternalQuery: q,
		Cause:         err,
	})
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package repository
import (
"strings"
"github.com/pkg/errors"
)
// ErrPtrValueMustBeAPointerAndScalar is returned by checkArgIsPtrAndScalar when
// its argument is not a pointer to a supported scalar type.
var ErrPtrValueMustBeAPointerAndScalar = errors.New("ptr value must be a pointer and scalar")
// DBError is a driver-reported database failure enriched with the statement
// that triggered it.
//
// Code is the driver-specific error code (rendered inside SQLSTATE[...]),
// Severity an optional severity level, Message the driver's error text,
// Details extra diagnostic text, InternalQuery the executed SQL and Cause the
// original driver error.
type DBError struct {
	Code          string
	Severity      string
	Message       string
	Details       string
	InternalQuery string
	Cause         error
}

// Error renders the error as newline-separated lines:
//
//	SQLSTATE[<code>]: [<severity>: ]<message>
//	The SQL being executed was: <query>   (only when InternalQuery is set)
//	Details: <details>                    (only when Details is set)
//
// Trailing newlines are stripped from the result.
func (d *DBError) Error() string {
	head := "SQLSTATE[" + d.Code + "]: "
	if d.Severity != "" {
		head += d.Severity + ": "
	}

	lines := []string{head + d.Message}
	if d.InternalQuery != "" {
		lines = append(lines, "The SQL being executed was: "+d.InternalQuery)
	}
	if d.Details != "" {
		lines = append(lines, "Details: "+d.Details)
	}
	return strings.TrimRight(strings.Join(lines, "\n")+"\n", "\n")
}

// Unwrap exposes Cause so errors.Is / errors.As can reach the driver error.
func (d *DBError) Unwrap() error {
	return d.Cause
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package repository
import (
"fmt"
"strings"
"github.com/ClickHouse/clickhouse-go/v2"
"github.com/go-sql-driver/mysql"
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/dal/connection"
)
// minTableNameParts is the number of dot-separated parts of a qualified
// "schema.table" name.
const (
	minTableNameParts = 2
)
// Factory defines the interface for creating repository instances for specific database drivers.
type Factory interface {
// Create creates a new repository instance for the given connection and options.
Create(conn Connection, options *Options) (Repository, error)
// Supports returns true if this factory supports the given database driver.
Supports(driver connection.Driver) bool
}
// FactoryRegistry holds the known repository factories, consulted in order.
type FactoryRegistry struct {
	factories []Factory
}

// NewFactoryRegistry returns a registry preloaded with the Tarantool,
// PostgreSQL, MySQL and ClickHouse factories, checked in that order.
func NewFactoryRegistry() *FactoryRegistry {
	factories := []Factory{
		new(TarantoolFactory),
		new(PostgresFactory),
		new(MySQLFactory),
		new(ClickhouseFactory),
	}
	return &FactoryRegistry{factories: factories}
}
// Create builds a repository with the first registered factory whose Supports
// accepts conn's driver. If none does, it returns a nil repository and an error
// naming the driver.
//
//nolint:ireturn,nolintlint // its ok
func (r *FactoryRegistry) Create(conn Connection, options *Options) (Repository, error) {
	for i := range r.factories {
		if f := r.factories[i]; f.Supports(conn.Driver()) {
			return f.Create(conn, options)
		}
	}
	return nil, fmt.Errorf("driver \"%s\" doesn't support", conn.Driver())
}
// ClickhouseFactory builds repositories for ClickHouse connections.
type ClickhouseFactory struct{}

// Supports reports whether driver is the ClickHouse driver.
func (f *ClickhouseFactory) Supports(driver connection.Driver) bool {
	switch driver {
	case connection.DriverClickhouse:
		return true
	default:
		return false
	}
}
// Create builds a ClickHouse repository. The schema (database) name is taken
// from the connection DSN; the table name may carry a "schema." prefix, which
// is validated and then dropped. The database, table and cluster names are all
// checked with ValidateIdentifier, in that order.
//
//nolint:ireturn,nolintlint // its ok
func (f *ClickhouseFactory) Create(conn Connection, options *Options) (Repository, error) {
	dsn, err := clickhouse.ParseDSN(conn.DSN())
	if err != nil {
		return nil, errors.WithMessage(err, "parsing dsn")
	}

	database := dsn.Auth.Database
	if err = ValidateIdentifier(database); err != nil {
		return nil, errors.Wrap(err, "invalid database name in DSN")
	}

	table, err := parseAndValidateTableName(options.TableName)
	if err != nil {
		return nil, err
	}

	if err = ValidateIdentifier(options.ClusterName); err != nil {
		return nil, errors.Wrap(err, "invalid cluster name")
	}

	chOptions := &Options{
		SchemaName:  database,
		TableName:   table,
		ClusterName: options.ClusterName,
		Replicated:  options.Replicated,
	}
	return NewClickhouse(conn, chOptions), nil
}
// MySQLFactory builds repositories for MySQL connections.
type MySQLFactory struct{}

// Supports reports whether driver is the MySQL driver.
func (f *MySQLFactory) Supports(driver connection.Driver) bool {
	switch driver {
	case connection.DriverMySQL:
		return true
	default:
		return false
	}
}
// Create builds a MySQL repository. The schema name is the database named in
// the connection DSN. The configured table name is validated first, then the
// DSN database name.
//
//nolint:ireturn,nolintlint // its ok
func (f *MySQLFactory) Create(conn Connection, options *Options) (Repository, error) {
	cfg, err := mysql.ParseDSN(conn.DSN())
	if err != nil {
		return nil, errors.WithMessage(err, "parsing dsn")
	}
	if err = ValidateIdentifier(options.TableName); err != nil {
		return nil, errors.Wrap(err, "invalid table name")
	}
	if err = ValidateIdentifier(cfg.DBName); err != nil {
		return nil, errors.Wrap(err, "invalid database name in DSN")
	}
	mysqlOptions := &Options{SchemaName: cfg.DBName, TableName: options.TableName}
	return NewMySQL(conn, mysqlOptions), nil
}
// PostgresFactory builds repositories for PostgreSQL connections.
type PostgresFactory struct{}

// Supports reports whether driver is the PostgreSQL driver.
func (f *PostgresFactory) Supports(driver connection.Driver) bool {
	switch driver {
	case connection.DriverPostgres:
		return true
	default:
		return false
	}
}
// Create builds a PostgreSQL repository. The configured table name may be
// "schema.table"; without a prefix the postgresDefaultSchema is used. Both
// parts are validated by parseAndValidateSchemaTableName.
//
//nolint:ireturn,nolintlint // its ok
func (f *PostgresFactory) Create(conn Connection, options *Options) (Repository, error) {
	schema, table, err := parseAndValidateSchemaTableName(options.TableName, postgresDefaultSchema)
	if err != nil {
		return nil, err
	}
	pgOptions := &Options{SchemaName: schema, TableName: table}
	return NewPostgres(conn, pgOptions), nil
}
// TarantoolFactory builds repositories for Tarantool connections.
type TarantoolFactory struct{}

// Supports reports whether driver is the Tarantool driver.
func (f *TarantoolFactory) Supports(driver connection.Driver) bool {
	switch driver {
	case connection.DriverTarantool:
		return true
	default:
		return false
	}
}
// Create builds a Tarantool repository after validating the configured table
// name. SchemaName is deliberately left empty.
//
//nolint:ireturn,nolintlint // its ok
func (f *TarantoolFactory) Create(conn Connection, options *Options) (Repository, error) {
	table := options.TableName
	if err := ValidateIdentifier(table); err != nil {
		return nil, errors.Wrap(err, "invalid table name")
	}
	return NewTarantool(conn, &Options{TableName: table}), nil
}
// parseAndValidateTableName parses and validates a table name that may contain schema prefix.
// For schema.table format, it validates both parts and returns only the table name.
// For simple table names, it validates and returns the name as-is.
func parseAndValidateTableName(tableName string) (string, error) {
if strings.Contains(tableName, ".") {
parts := strings.Split(tableName, ".")
if len(parts) < minTableNameParts {
return "", errors.New("invalid table name format: expected schema.table")
}
if err := ValidateIdentifier(parts[0]); err != nil {
return "", errors.Wrap(err, "invalid schema name in table name")
}
if err := ValidateIdentifier(parts[1]); err != nil {
return "", errors.Wrap(err, "invalid table name")
}
return parts[1], nil
}
if err := ValidateIdentifier(tableName); err != nil {
return "", errors.Wrap(err, "invalid table name")
}
return tableName, nil
}
// parseAndValidateSchemaTableName parses and validates a table name that may contain schema prefix.
// Returns schema name, table name, and error. If no schema prefix, uses defaultSchema.
func parseAndValidateSchemaTableName(tableName, defaultSchema string) (schema, table string, err error) {
if strings.Contains(tableName, ".") {
parts := strings.Split(tableName, ".")
if len(parts) < minTableNameParts {
return "", "", errors.New("invalid table name format: expected schema.table")
}
if err := ValidateIdentifier(parts[0]); err != nil {
return "", "", errors.Wrap(err, "invalid schema name in table name")
}
if err := ValidateIdentifier(parts[1]); err != nil {
return "", "", errors.Wrap(err, "invalid table name")
}
return parts[0], parts[1], nil
}
if err := ValidateIdentifier(tableName); err != nil {
return "", "", errors.Wrap(err, "invalid table name")
}
return defaultSchema, tableName, nil
}
package repository
import (
"time"
)
// checkArgIsPtrAndScalar returns nil when ptr is a pointer to a signed or
// unsigned integer, a float, a bool, a string or a time.Time, and
// ErrPtrValueMustBeAPointerAndScalar for anything else (including untyped nil).
func checkArgIsPtrAndScalar(ptr any) error {
	switch ptr.(type) {
	case *int, *int8, *int16, *int32, *int64,
		*uint, *uint8, *uint16, *uint32, *uint64,
		*float32, *float64,
		*bool, *string, *time.Time:
		return nil
	}
	return ErrPtrValueMustBeAPointerAndScalar
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package repository
import (
"context"
"fmt"
"strconv"
"time"
"github.com/go-sql-driver/mysql"
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/dal/entity"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/sqlex"
)
// MySQL is the MySQL implementation of Repository: it maintains the migration
// history table and executes migration SQL through its connection.
type MySQL struct {
	conn    Connection
	options *Options
}

// NewMySQL returns a MySQL repository that talks to the database through conn
// and takes the history table and schema names from options.
func NewMySQL(conn Connection, options *Options) *MySQL {
	repo := new(MySQL)
	repo.conn, repo.options = conn, options
	return repo
}
// Migrations returns up to limit applied migrations, newest first (ordered by
// apply_time, then version, both descending).
func (m *MySQL) Migrations(ctx context.Context, limit int) (entity.Migrations, error) {
	const op = "get migrations"
	query := fmt.Sprintf(
		`
SELECT version, apply_time
FROM %s
ORDER BY apply_time DESC, version DESC
LIMIT ?`,
		m.options.TableName,
	)
	rows, err := m.conn.QueryContext(ctx, query, limit)
	if err != nil {
		return nil, errors.Wrap(m.dbError(err, query), op)
	}
	defer rows.Close()

	var history entity.Migrations
	for rows.Next() {
		var (
			ver       string
			appliedAt int64
		)
		if err = rows.Scan(&ver, &appliedAt); err != nil {
			return nil, errors.Wrap(m.dbError(err, query), op)
		}
		history = append(history, entity.Migration{Version: ver, ApplyTime: appliedAt})
	}
	if err = rows.Err(); err != nil {
		return nil, errors.Wrap(m.dbError(err, query), op)
	}
	return history, nil
}
// HasMigrationHistoryTable reports whether information_schema.tables lists the
// configured history table within the configured schema.
func (m *MySQL) HasMigrationHistoryTable(ctx context.Context) (exists bool, err error) {
	const op = "get schema table"
	query := `
SELECT EXISTS(
SELECT *
FROM information_schema.tables
WHERE table_name = ? AND table_schema = ?
)
`
	var rows sqlex.Rows
	rows, err = m.conn.QueryContext(ctx, query, m.options.TableName, m.options.SchemaName)
	if err != nil {
		return false, errors.Wrap(m.dbError(err, query), op)
	}
	defer rows.Close()

	for rows.Next() {
		if err = rows.Scan(&exists); err != nil {
			return false, errors.Wrap(m.dbError(err, query), op)
		}
	}
	if err = rows.Err(); err != nil {
		return false, errors.Wrap(m.dbError(err, query), op)
	}
	return exists, nil
}
// InsertMigration records version as applied at the current Unix time.
func (m *MySQL) InsertMigration(ctx context.Context, version string) error {
	query := fmt.Sprintf(`
INSERT INTO %s (version, apply_time)
VALUES (?, ?)`,
		m.options.TableName,
	)
	//nolint:gosec // overflow ok
	appliedAt := uint32(time.Now().Unix())
	_, err := m.conn.ExecContext(ctx, query, version, appliedAt)
	if err != nil {
		return errors.Wrap(m.dbError(err, query), "insert migration")
	}
	return nil
}
// RemoveMigration deletes the history row for version.
func (m *MySQL) RemoveMigration(ctx context.Context, version string) error {
	query := fmt.Sprintf(`DELETE FROM %s WHERE version = ?`, m.options.TableName)
	_, err := m.conn.ExecContext(ctx, query, version)
	if err == nil {
		return nil
	}
	return errors.Wrap(m.dbError(err, query), "remove migration")
}
// ExecQuery executes query with the placeholder values in args, discarding any
// result rows.
//
// MySQL server errors are converted to *DBError (error number, message and the
// failing query) via dbError, matching the other repositories' ExecQuery.
// The original driver error stays reachable through errors.As / errors.Is.
func (m *MySQL) ExecQuery(ctx context.Context, query string, args ...any) error {
	if _, err := m.conn.ExecContext(ctx, query, args...); err != nil {
		return m.dbError(err, query)
	}
	return nil
}
// ExecQueryTransaction runs txFn inside a transaction opened by the underlying
// connection and returns the error reported by the connection's Transaction.
// It takes no SQL itself; txFn issues its own statements using the context
// it is handed (presumably carrying the transaction — confirm in Connection).
func (m *MySQL) ExecQueryTransaction(ctx context.Context, txFn func(ctx context.Context) error) error {
	conn := m.conn
	return conn.Transaction(ctx, txFn)
}
// CreateMigrationHistoryTable creates the InnoDB history table with version as
// its primary key and an INT apply_time column.
func (m *MySQL) CreateMigrationHistoryTable(ctx context.Context) error {
	ddl := fmt.Sprintf(
		`
CREATE TABLE %s (
version VARCHAR(180) PRIMARY KEY,
apply_time INT
)
ENGINE=InnoDB
`,
		m.options.TableName,
	)
	_, err := m.conn.ExecContext(ctx, ddl)
	if err != nil {
		return errors.Wrap(m.dbError(err, ddl), "create migration history table")
	}
	return nil
}
// DropMigrationHistoryTable drops the history table.
func (m *MySQL) DropMigrationHistoryTable(ctx context.Context) error {
	ddl := fmt.Sprintf(`DROP TABLE %s`, m.options.TableName)
	_, err := m.conn.ExecContext(ctx, ddl)
	if err == nil {
		return nil
	}
	return errors.Wrap(m.dbError(err, ddl), "drop migration history table")
}
// MigrationsCount returns the number of rows in the history table, or 0 with
// the error on failure.
func (m *MySQL) MigrationsCount(ctx context.Context) (int, error) {
	var count int
	err := m.QueryScalar(ctx, fmt.Sprintf(`SELECT count(*) FROM %s`, m.options.TableName), &count)
	if err != nil {
		return 0, err
	}
	return count, nil
}
// QueryScalar executes query and scans the first column of the first row into ptr.
// ptr must be a pointer to a scalar type accepted by checkArgIsPtrAndScalar
// (int, string, bool, time.Time, ...). If the query yields no rows, ptr is left
// untouched and nil is returned.
//
// Every driver failure, including the one from starting the query, is passed
// through dbError so MySQL server errors consistently surface as *DBError.
func (m *MySQL) QueryScalar(ctx context.Context, query string, ptr any) error {
	if err := checkArgIsPtrAndScalar(ptr); err != nil {
		return err
	}
	rows, err := m.conn.QueryContext(ctx, query)
	if err != nil {
		return m.dbError(err, query)
	}
	defer rows.Close()
	if rows.Next() {
		if err := rows.Scan(ptr); err != nil {
			return m.dbError(err, query)
		}
	}
	if err := rows.Err(); err != nil {
		return m.dbError(err, query)
	}
	return nil
}
// ExistsMigration reports whether the history table (schema-qualified) holds a
// row for version.
//
// All driver failures, including the one from starting the query, go through
// dbError, consistent with the other read paths of this repository.
func (m *MySQL) ExistsMigration(ctx context.Context, version string) (bool, error) {
	q := fmt.Sprintf(`SELECT 1 FROM %s WHERE version = ?`, m.TableNameWithSchema())
	rows, err := m.conn.QueryContext(ctx, q, version)
	if err != nil {
		return false, m.dbError(err, q)
	}
	defer rows.Close()
	var exists int
	if rows.Next() {
		if err := rows.Scan(&exists); err != nil {
			return false, m.dbError(err, q)
		}
	}
	if err := rows.Err(); err != nil {
		return false, m.dbError(err, q)
	}
	return exists == 1, nil
}
// TableNameWithSchema returns the history table qualified with its schema, as
// "schema.table".
func (m *MySQL) TableNameWithSchema() string {
	return fmt.Sprintf("%s.%s", m.options.SchemaName, m.options.TableName)
}
// InsertMigrationWithApplyTime records version as applied at the given Unix
// timestamp instead of now. applyTime is converted to uint32 before binding;
// out-of-range values wrap.
func (m *MySQL) InsertMigrationWithApplyTime(ctx context.Context, version string, applyTime int64) error {
	query := fmt.Sprintf(`
INSERT INTO %s (version, apply_time)
VALUES (?, ?)`,
		m.options.TableName,
	)
	//nolint:gosec // overflow ok
	ts := uint32(applyTime)
	_, err := m.conn.ExecContext(ctx, query, version, ts)
	if err != nil {
		return errors.Wrap(m.dbError(err, query), "insert migration")
	}
	return nil
}
// MigrationsByMaxApplyTime returns every migration whose apply_time equals the
// table's maximum apply_time, ordered by version descending.
func (m *MySQL) MigrationsByMaxApplyTime(ctx context.Context) (entity.Migrations, error) {
	const op = "get migrations by max apply time"
	table := m.options.TableName
	query := fmt.Sprintf(`
SELECT version, apply_time
FROM %s
WHERE apply_time = (SELECT MAX(apply_time) FROM %s)
ORDER BY version DESC`,
		table,
		table,
	)
	rows, err := m.conn.QueryContext(ctx, query)
	if err != nil {
		return nil, errors.Wrap(m.dbError(err, query), op)
	}
	defer rows.Close()

	var latest entity.Migrations
	for rows.Next() {
		var (
			ver       string
			appliedAt int64
		)
		if err = rows.Scan(&ver, &appliedAt); err != nil {
			return nil, errors.Wrap(m.dbError(err, query), op)
		}
		latest = append(latest, entity.Migration{Version: ver, ApplyTime: appliedAt})
	}
	if err = rows.Err(); err != nil {
		return nil, errors.Wrap(m.dbError(err, query), op)
	}
	return latest, nil
}
// dbError converts a *mysql.MySQLError found anywhere in err's chain into a
// *DBError carrying the MySQL error number, message and the query q.
// Any other error, including nil, is returned unchanged.
func (m *MySQL) dbError(err error, q string) error {
	var mysqlErr *mysql.MySQLError
	if errors.As(err, &mysqlErr) {
		return errors.WithStack(&DBError{
			Code:          strconv.Itoa(int(mysqlErr.Number)),
			Message:       mysqlErr.Message,
			InternalQuery: q,
			Cause:         err,
		})
	}
	return err
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package repository
import (
"context"
"fmt"
"time"
"github.com/lib/pq"
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/dal/entity"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/sqlex"
)
// postgresDefaultSchema is the schema used when the configured table name has
// no "schema." prefix.
const (
	postgresDefaultSchema = "public"
)
// Postgres is the PostgreSQL implementation of Repository: it maintains the
// migration history table and executes migration SQL through its connection.
type Postgres struct {
	conn    Connection
	options *Options
}

// NewPostgres returns a Postgres repository that talks to the database through
// conn and takes the history table and schema names from options.
func NewPostgres(conn Connection, options *Options) *Postgres {
	repo := new(Postgres)
	repo.conn, repo.options = conn, options
	return repo
}
// Migrations returns up to limit applied migrations from the schema-qualified
// history table, newest first (ordered by apply_time, then version, both
// descending).
func (p *Postgres) Migrations(ctx context.Context, limit int) (entity.Migrations, error) {
	const op = "get migrations"
	query := fmt.Sprintf(
		`
SELECT version, apply_time
FROM %s
ORDER BY apply_time DESC, version DESC
LIMIT $1`,
		p.TableNameWithSchema(),
	)
	rows, err := p.conn.QueryContext(ctx, query, limit)
	if err != nil {
		return nil, errors.Wrap(p.dbError(err, query), op)
	}
	defer rows.Close()

	var history entity.Migrations
	for rows.Next() {
		var (
			ver       string
			appliedAt int64
		)
		if err = rows.Scan(&ver, &appliedAt); err != nil {
			return nil, errors.Wrap(p.dbError(err, query), op)
		}
		history = append(history, entity.Migration{Version: ver, ApplyTime: appliedAt})
	}
	if err = rows.Err(); err != nil {
		return nil, errors.Wrap(p.dbError(err, query), op)
	}
	return history, nil
}
// HasMigrationHistoryTable reports whether pg_class, joined with pg_namespace,
// contains a relation with the configured table name in the configured schema.
func (p *Postgres) HasMigrationHistoryTable(ctx context.Context) (exists bool, err error) {
	const op = "get schema table"
	query := `
SELECT
d.nspname AS table_schema,
c.relname AS table_name
FROM pg_class c
LEFT JOIN pg_namespace d ON d.oid = c.relnamespace
WHERE (c.relname, d.nspname) = ($1, $2)
`
	var rows sqlex.Rows
	rows, err = p.conn.QueryContext(ctx, query, p.options.TableName, p.options.SchemaName)
	if err != nil {
		return false, errors.Wrap(p.dbError(err, query), op)
	}
	defer rows.Close()

	for rows.Next() {
		var schemaName, relName string
		if err = rows.Scan(&schemaName, &relName); err != nil {
			return false, errors.Wrap(p.dbError(err, query), op)
		}
		// TODO: scan columns into a table schema.
		if relName == p.options.TableName {
			return true, nil
		}
	}
	if err = rows.Err(); err != nil {
		return false, errors.Wrap(p.dbError(err, query), op)
	}
	return false, nil
}
// InsertMigration records version as applied at the current Unix time.
func (p *Postgres) InsertMigration(ctx context.Context, version string) error {
	query := fmt.Sprintf(`
INSERT INTO %s (version, apply_time)
VALUES ($1, $2)`,
		p.TableNameWithSchema(),
	)
	//nolint:gosec // overflow ok
	appliedAt := uint32(time.Now().Unix())
	_, err := p.conn.ExecContext(ctx, query, version, appliedAt)
	if err != nil {
		return errors.Wrap(p.dbError(err, query), "insert migration")
	}
	return nil
}
// RemoveMigration deletes the history row for version.
func (p *Postgres) RemoveMigration(ctx context.Context, version string) error {
	query := fmt.Sprintf(`DELETE FROM %s WHERE (version) = ($1)`, p.TableNameWithSchema())
	_, err := p.conn.ExecContext(ctx, query, version)
	if err == nil {
		return nil
	}
	return errors.Wrap(p.dbError(err, query), "remove migration")
}
// ExecQuery executes query with the placeholder values in args, discarding any
// result rows. Failures are passed through dbError.
func (p *Postgres) ExecQuery(ctx context.Context, query string, args ...any) error {
	_, err := p.conn.ExecContext(ctx, query, args...)
	if err == nil {
		return nil
	}
	return p.dbError(err, query)
}
// ExecQueryTransaction runs txFn inside a transaction opened by the underlying
// connection and returns the error reported by the connection's Transaction.
// It takes no SQL itself; txFn issues its own statements using the context
// it is handed (presumably carrying the transaction — confirm in Connection).
func (p *Postgres) ExecQueryTransaction(ctx context.Context, txFn func(ctx context.Context) error) error {
	conn := p.conn
	return conn.Transaction(ctx, txFn)
}
// CreateMigrationHistoryTable creates the schema-qualified history table with
// version as its primary key and an integer apply_time column.
func (p *Postgres) CreateMigrationHistoryTable(ctx context.Context) error {
	ddl := fmt.Sprintf(
		`
CREATE TABLE %s (
version varchar(180) PRIMARY KEY,
apply_time integer
)
`,
		p.TableNameWithSchema(),
	)
	_, err := p.conn.ExecContext(ctx, ddl)
	if err != nil {
		return errors.Wrap(p.dbError(err, ddl), "create migration history table")
	}
	return nil
}
// DropMigrationHistoryTable drops the schema-qualified history table.
func (p *Postgres) DropMigrationHistoryTable(ctx context.Context) error {
	ddl := fmt.Sprintf(`DROP TABLE %s`, p.TableNameWithSchema())
	_, err := p.conn.ExecContext(ctx, ddl)
	if err == nil {
		return nil
	}
	return errors.Wrap(p.dbError(err, ddl), "drop migration history table")
}
// MigrationsCount returns the number of rows in the history table, or 0 with
// the error on failure.
func (p *Postgres) MigrationsCount(ctx context.Context) (int, error) {
	var count int
	err := p.QueryScalar(ctx, fmt.Sprintf(`SELECT count(*) FROM %s`, p.TableNameWithSchema()), &count)
	if err != nil {
		return 0, err
	}
	return count, nil
}
// QueryScalar executes query and scans the first column of the first row into ptr.
// ptr must be a pointer to a scalar type accepted by checkArgIsPtrAndScalar
// (int, string, bool, time.Time, ...). If the query yields no rows, ptr is left
// untouched and nil is returned.
//
// Every driver failure, including the one from starting the query, is passed
// through dbError, as ExecQuery already does for this repository.
func (p *Postgres) QueryScalar(ctx context.Context, query string, ptr any) error {
	if err := checkArgIsPtrAndScalar(ptr); err != nil {
		return err
	}
	rows, err := p.conn.QueryContext(ctx, query)
	if err != nil {
		return p.dbError(err, query)
	}
	defer rows.Close()
	if rows.Next() {
		if err := rows.Scan(ptr); err != nil {
			return p.dbError(err, query)
		}
	}
	if err := rows.Err(); err != nil {
		return p.dbError(err, query)
	}
	return nil
}
// ExistsMigration reports whether the history table contains a row for version.
// The query is a single SELECT EXISTS(...) with version bound as $1, so it
// yields one boolean row. Every driver error, including one raised while
// sending the query, is converted by dbError.
func (p *Postgres) ExistsMigration(ctx context.Context, version string) (bool, error) {
	q := fmt.Sprintf(`SELECT EXISTS(SELECT 1 FROM %s WHERE version = $1)`, p.TableNameWithSchema())
	rows, err := p.conn.QueryContext(ctx, q, version)
	if err != nil {
		return false, p.dbError(err, q)
	}
	defer rows.Close()
	var exists bool
	if rows.Next() {
		if err := rows.Scan(&exists); err != nil {
			return false, p.dbError(err, q)
		}
	}
	if err := rows.Err(); err != nil {
		return false, p.dbError(err, q)
	}
	return exists, nil
}
// TableNameWithSchema returns the history table name qualified by its schema,
// formatted as "<schema>.<table>".
func (p *Postgres) TableNameWithSchema() string {
	opts := p.options
	return opts.SchemaName + "." + opts.TableName
}
// InsertMigrationWithApplyTime records version in the history table using the
// caller-supplied applyTime instead of the current time.
func (p *Postgres) InsertMigrationWithApplyTime(ctx context.Context, version string, applyTime int64) error {
	q := fmt.Sprintf(`
INSERT INTO %s (version, apply_time)
VALUES ($1, $2)`, p.TableNameWithSchema())
	// applyTime is narrowed to uint32 before it is sent; out-of-range values
	// wrap silently.
	//nolint:gosec // overflow ok
	stamp := uint32(applyTime)
	_, err := p.conn.ExecContext(ctx, q, version, stamp)
	if err == nil {
		return nil
	}
	return errors.Wrap(p.dbError(err, q), "insert migration")
}
// MigrationsByMaxApplyTime loads every history row whose apply_time equals the
// greatest apply_time in the table, ordered by version descending. An empty
// table yields a nil slice.
func (p *Postgres) MigrationsByMaxApplyTime(ctx context.Context) (entity.Migrations, error) {
	const op = "get migrations by max apply time"
	table := p.TableNameWithSchema()
	q := fmt.Sprintf(`
SELECT version, apply_time
FROM %s
WHERE apply_time = (SELECT MAX(apply_time) FROM %s)
ORDER BY version DESC`, table, table)
	rows, err := p.conn.QueryContext(ctx, q)
	if err != nil {
		return nil, errors.Wrap(p.dbError(err, q), op)
	}
	defer rows.Close()
	var result entity.Migrations
	for rows.Next() {
		var version string
		var applyTime int64
		if scanErr := rows.Scan(&version, &applyTime); scanErr != nil {
			return nil, errors.Wrap(p.dbError(scanErr, q), op)
		}
		result = append(result, entity.Migration{Version: version, ApplyTime: applyTime})
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, errors.Wrap(p.dbError(iterErr, q), op)
	}
	return result, nil
}
// dbError converts a *pq.Error found in err's chain into a *DBError carrying
// the SQLSTATE code, severity, message, detail and query text (stack
// attached). Any other error is returned unchanged. When q is empty, the
// query reported by the server (InternalQuery) is used instead.
func (p *Postgres) dbError(err error, q string) error {
	var pgErr *pq.Error
	if errors.As(err, &pgErr) {
		query := q
		if query == "" {
			query = pgErr.InternalQuery
		}
		return errors.WithStack(&DBError{
			Code:          pgErr.SQLState(),
			Severity:      pgErr.Severity,
			Message:       pgErr.Message,
			Details:       pgErr.Detail,
			InternalQuery: query,
			Cause:         err,
		})
	}
	return err
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package repository
import (
"context"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/dal/entity"
)
// Repository abstracts the database-specific storage of migration history and
// the execution of raw queries against the target database.
type Repository interface {
	// TableNameWithSchema returns the fully qualified history table name.
	TableNameWithSchema() string

	// HasMigrationHistoryTable reports whether the history table exists.
	HasMigrationHistoryTable(ctx context.Context) (exists bool, err error)
	// CreateMigrationHistoryTable creates the history table.
	CreateMigrationHistoryTable(ctx context.Context) error
	// DropMigrationHistoryTable drops the history table.
	DropMigrationHistoryTable(ctx context.Context) error

	// Migrations returns applied migrations, limited to limit records.
	Migrations(ctx context.Context, limit int) (entity.Migrations, error)
	// MigrationsByMaxApplyTime returns the migrations sharing the largest apply_time.
	MigrationsByMaxApplyTime(ctx context.Context) (entity.Migrations, error)
	// MigrationsCount returns the number of applied migrations.
	MigrationsCount(ctx context.Context) (int, error)
	// ExistsMigration reports whether version is recorded in the history.
	ExistsMigration(ctx context.Context, version string) (bool, error)

	// InsertMigration records version as applied.
	InsertMigration(ctx context.Context, version string) error
	// InsertMigrationWithApplyTime records version with an explicit apply time.
	InsertMigrationWithApplyTime(ctx context.Context, version string, applyTime int64) error
	// RemoveMigration deletes version from the history.
	RemoveMigration(ctx context.Context, version string) error

	// ExecQuery runs a statement that returns no rows, binding args.
	ExecQuery(ctx context.Context, query string, args ...any) error
	// QueryScalar runs query and stores its single scalar result in ptr.
	QueryScalar(ctx context.Context, query string, ptr any) error
	// ExecQueryTransaction runs fnTx inside a database transaction.
	ExecQueryTransaction(ctx context.Context, fnTx func(ctx context.Context) error) error
}
// New builds a Repository for conn and options by delegating construction to a
// freshly created factory registry.
//
//nolint:ireturn,nolintlint // it's ok
func New(conn Connection, options *Options) (Repository, error) {
	return NewFactoryRegistry().Create(conn, options)
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package repository
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/dal/entity"
"github.com/tarantool/go-tarantool/v2"
)
// Tarantool index iterator names interpolated into Lua select/count options.
const (
	// tarantoolIteratorLT walks keys strictly less than the search key.
	tarantoolIteratorLT = "LT"
	// tarantoolIteratorEQ walks keys equal to the search key.
	tarantoolIteratorEQ = "EQ"
)
// Tarantool implements Repository interface for Tarantool database.
// It handles migration history tracking and Lua script execution for Tarantool.
// Every query built by its methods is a Lua chunk (for example
// "box.space.<name>:select(...)") handed to conn for execution.
type Tarantool struct {
// conn executes the Lua chunks built by the methods of this type.
conn Connection
// options supplies the history space name (TableName); SchemaName is not
// used by this implementation (see TableNameWithSchema).
options *Options
}
// NewTarantool returns a Tarantool repository bound to conn and configured by
// options. Neither argument is validated here.
func NewTarantool(conn Connection, options *Options) *Tarantool {
	repo := new(Tarantool)
	repo.conn = conn
	repo.options = options
	return repo
}
// Migrations returns up to limit applied-migration records. They are read from
// the primary (version) index with the 'LT' iterator and an empty key, i.e. in
// descending version order.
// Errors from sending the query, scanning a row or iterating are wrapped
// with "get migrations".
func (p *Tarantool) Migrations(ctx context.Context, limit int) (entity.Migrations, error) {
	q := fmt.Sprintf("return box.space.%s:select({}, {iterator='%s', limit = %d})",
		p.TableNameWithSchema(),
		tarantoolIteratorLT,
		limit,
	)
	rows, err := p.conn.QueryContext(ctx, q)
	if err != nil {
		return nil, errors.Wrap(p.dbError(err, q), "get migrations")
	}
	defer rows.Close()
	var migrations entity.Migrations
	for rows.Next() {
		var (
			version   string
			applyTime int64
		)
		if err := rows.Scan(&version, &applyTime); err != nil {
			return nil, errors.Wrap(p.dbError(err, q), "get migrations")
		}
		migrations = append(migrations,
			entity.Migration{
				Version:   version,
				ApplyTime: applyTime,
			},
		)
	}
	// Surface iteration errors instead of returning a possibly truncated list
	// (the Postgres implementation checks rows.Err() the same way).
	if err := rows.Err(); err != nil {
		return nil, errors.Wrap(p.dbError(err, q), "get migrations")
	}
	return migrations, nil
}
// HasMigrationHistoryTable reports whether the history space exists. It
// evaluates "return box.space.<name> ~= nil" and scans the boolean result.
// No returned row means false. Errors are wrapped with "get schema table".
func (p *Tarantool) HasMigrationHistoryTable(ctx context.Context) (exists bool, err error) {
	q := fmt.Sprintf("return box.space.%s ~= nil", p.TableNameWithSchema())
	rows, err := p.conn.QueryContext(ctx, q)
	if err != nil {
		return false, errors.Wrap(p.dbError(err, q), "get schema table")
	}
	defer rows.Close()
	if rows.Next() {
		if err := rows.Scan(&exists); err != nil {
			return false, errors.Wrap(p.dbError(err, q), "get schema table")
		}
	}
	// Check for iteration errors, matching the other query paths.
	if err := rows.Err(); err != nil {
		return false, errors.Wrap(p.dbError(err, q), "get schema table")
	}
	return exists, nil
}
// InsertMigration records version in the history space, stamped with the
// current Unix time in seconds. It is InsertMigrationWithApplyTime with
// time.Now().Unix() as the apply time.
func (p *Tarantool) InsertMigration(ctx context.Context, version string) error {
	return p.InsertMigrationWithApplyTime(ctx, version, time.Now().Unix())
}
// RemoveMigration deletes the history tuple whose primary key is version. The
// version is passed to the Lua chunk as its vararg (...).
func (p *Tarantool) RemoveMigration(ctx context.Context, version string) error {
	q := "box.space." + p.TableNameWithSchema() + ":delete(...)"
	_, err := p.conn.ExecContext(ctx, q, version)
	if err == nil {
		return nil
	}
	return errors.Wrap(p.dbError(err, q), "remove migration")
}
// ExecQuery evaluates query with args and discards any result. Driver errors
// are converted by dbError.
func (p *Tarantool) ExecQuery(ctx context.Context, query string, args ...any) error {
	_, err := p.conn.ExecContext(ctx, query, args...)
	if err == nil {
		return nil
	}
	return p.dbError(err, query)
}
// ExecQueryTransaction runs txFn inside a transaction opened by the underlying
// connection and returns whatever the connection's Transaction reports.
func (p *Tarantool) ExecQueryTransaction(ctx context.Context, txFn func(ctx context.Context) error) error {
	conn := p.conn
	return conn.Transaction(ctx, txFn)
}
// CreateMigrationHistoryTable creates the history space and configures it. It
// runs four statements in order and stops at the first failure:
//  1. create the space (if_not_exists);
//  2. set the format {version: string, apply_time: unsigned}, both non-nullable;
//  3. create the 'primary' index on version;
//  4. create the 'secondary' index on (apply_time, version).
func (p *Tarantool) CreateMigrationHistoryTable(ctx context.Context) error {
	space := p.TableNameWithSchema()
	statements := []string{
		fmt.Sprintf("box.schema.space.create('%s', {if_not_exists = true})", space),
		"box.space." + space + ":format" +
			"({{'version',type = 'string',is_nullable = false},{'apply_time', type = 'unsigned', is_nullable = false}})",
		"box.space." + space + ":create_index" +
			"('primary', {parts = {'version'}, if_not_exists = true})",
		"box.space." + space + ":create_index" +
			"('secondary', {parts = {{'apply_time'}, {'version'}}, if_not_exists = true})",
	}
	for _, stmt := range statements {
		if _, err := p.conn.ExecContext(ctx, stmt); err != nil {
			return errors.Wrap(p.dbError(err, stmt), "create migration history table")
		}
	}
	return nil
}
// DropMigrationHistoryTable drops the history space with box.space.<name>:drop().
func (p *Tarantool) DropMigrationHistoryTable(ctx context.Context) error {
	q := "box.space." + p.TableNameWithSchema() + ":drop()"
	_, err := p.conn.ExecContext(ctx, q)
	if err == nil {
		return nil
	}
	return errors.Wrap(p.dbError(err, q), "drop migration history table")
}
// MigrationsCount reports the number of tuples in the history space via
// box.space.<name>:len().
func (p *Tarantool) MigrationsCount(ctx context.Context) (int, error) {
	q := strings.Join([]string{"return box.space.", p.TableNameWithSchema(), ":len()"}, "")
	count := 0
	if err := p.QueryScalar(ctx, q, &count); err != nil {
		return 0, err
	}
	return count, nil
}
// QueryScalar evaluates query and scans the first returned row into ptr, which
// must be a pointer to a scalar type (checked by checkArgIsPtrAndScalar).
// With no rows, ptr is left unchanged and nil is returned. All driver errors
// are converted by dbError.
func (p *Tarantool) QueryScalar(ctx context.Context, query string, ptr any) error {
	if argErr := checkArgIsPtrAndScalar(ptr); argErr != nil {
		return argErr
	}
	rows, queryErr := p.conn.QueryContext(ctx, query)
	if queryErr != nil {
		return p.dbError(queryErr, query)
	}
	defer rows.Close()
	if rows.Next() {
		if scanErr := rows.Scan(ptr); scanErr != nil {
			return p.dbError(scanErr, query)
		}
	}
	if iterErr := rows.Err(); iterErr != nil {
		return p.dbError(iterErr, query)
	}
	return nil
}
// ExistsMigration reports whether the history space holds a tuple for version.
//
// The Lua chunk must `return` its value: eval sends back only what the chunk
// returns, so without it no row would ever be produced and the answer would
// always be false. The chunk compares the count with 0 so that a boolean is
// scanned. version is passed as the chunk's vararg (...) rather than spliced
// into the Lua source, so quotes in it cannot alter the script.
func (p *Tarantool) ExistsMigration(ctx context.Context, version string) (bool, error) {
	q := fmt.Sprintf("return box.space.%s:count(..., {iterator='%s'}) > 0",
		p.TableNameWithSchema(),
		tarantoolIteratorEQ,
	)
	rows, err := p.conn.QueryContext(ctx, q, version)
	if err != nil {
		return false, errors.Wrap(p.dbError(err, q), "exists migration")
	}
	defer rows.Close()
	var exists bool
	if rows.Next() {
		if err := rows.Scan(&exists); err != nil {
			return false, p.dbError(err, q)
		}
	}
	if err := rows.Err(); err != nil {
		return false, p.dbError(err, q)
	}
	return exists, nil
}
// TableNameWithSchema returns the history space name. Tarantool spaces have no
// schema namespace, so options.SchemaName is ignored.
func (p *Tarantool) TableNameWithSchema() string {
	name := p.options.TableName
	return name
}
// InsertMigrationWithApplyTime inserts the tuple {version, applyTime} into the
// history space. Both values are passed as the chunk's vararg (...).
func (p *Tarantool) InsertMigrationWithApplyTime(ctx context.Context, version string, applyTime int64) error {
	q := "box.space." + p.TableNameWithSchema() + ":insert({...})"
	_, err := p.conn.ExecContext(ctx, q, version, applyTime)
	if err == nil {
		return nil
	}
	return errors.Wrap(p.dbError(err, q), "insert migration")
}
// MigrationsByMaxApplyTime returns every history tuple whose apply_time equals
// the largest apply_time in the space, ordered by version descending (the same
// order as the Postgres implementation). An empty space yields a nil slice.
//
// Both steps run in a single Lua chunk:
//   - secondary:max() finds the greatest (apply_time, version) tuple;
//   - its second field (apply_time) keys a reverse-equality ('REQ') select on
//     the secondary index.
//
// The chunk returns either nothing (empty space) or a table of tuples, so each
// row reaching the client is a {version, apply_time} pair. Returning the bare
// tuple from max() would let the driver unwrap it into one single-value row
// per field, which cannot be scanned into two destinations.
func (p *Tarantool) MigrationsByMaxApplyTime(ctx context.Context) (entity.Migrations, error) {
	const op = "get migrations by max apply time"
	q := fmt.Sprintf(`local idx = box.space.%s.index.secondary
local top = idx:max()
if top == nil then return end
return idx:select({top[2]}, {iterator = 'REQ'})`,
		p.TableNameWithSchema(),
	)
	rows, err := p.conn.QueryContext(ctx, q)
	if err != nil {
		return nil, errors.Wrap(p.dbError(err, q), op)
	}
	defer rows.Close()
	var migrations entity.Migrations
	for rows.Next() {
		var (
			version   string
			applyTime int64
		)
		if err := rows.Scan(&version, &applyTime); err != nil {
			return nil, errors.Wrap(p.dbError(err, q), op)
		}
		migrations = append(migrations, entity.Migration{
			Version:   version,
			ApplyTime: applyTime,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, errors.Wrap(p.dbError(err, q), op)
	}
	return migrations, nil
}
// dbError converts a tarantool.Error found in err's chain into a *DBError. The
// DBError carries the numeric code as a decimal string, a fixed "ERR" severity,
// the server message and the query q, with a stack attached. Any other error
// is returned unchanged.
func (p *Tarantool) dbError(err error, q string) error {
	var tErr tarantool.Error
	if errors.As(err, &tErr) {
		return errors.WithStack(&DBError{
			Code:          strconv.Itoa(int(tErr.Code)),
			Severity:      "ERR",
			Message:       tErr.Msg,
			Details:       "",
			InternalQuery: q,
			Cause:         err,
		})
	}
	return err
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package repository
import "github.com/pkg/errors"
// maxIdentifierLen is the largest identifier length, in bytes, accepted by
// ValidateIdentifier.
const maxIdentifierLen = 65_000

// ErrInvalidIdentifier is the sentinel wrapped by ValidateIdentifier when a
// name is too long or contains a character outside [A-Za-z0-9_].
var ErrInvalidIdentifier = errors.New("invalid SQL identifier")
// ValidateIdentifier checks a table, schema or cluster name before it is used
// in a query.
//
// The empty string is accepted because such fields are optional. Otherwise the
// name must be at most maxIdentifierLen bytes long and consist only of ASCII
// letters, ASCII digits and underscores. Violations return an error wrapping
// ErrInvalidIdentifier.
func ValidateIdentifier(name string) error {
	switch {
	case name == "":
		return nil
	case len(name) > maxIdentifierLen:
		return errors.Wrapf(ErrInvalidIdentifier, "identifier too long: %d chars", len(name))
	}
	for _, r := range name {
		if isValidIdentifierChar(r) {
			continue
		}
		return errors.Wrapf(ErrInvalidIdentifier, "invalid character '%c' in: %s", r, name)
	}
	return nil
}
// isValidIdentifierChar reports whether r is an ASCII letter, an ASCII digit
// or an underscore.
func isValidIdentifierChar(r rune) bool {
	switch {
	case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
		return true
	default:
		return r == '_'
	}
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package iofile
import (
"os"
"github.com/pkg/errors"
)
const (
	// fileModeExecutable is rwxr-xr-x, the permission CreateDirectory uses.
	fileModeExecutable = 0o755
)
// Exists reports whether a file or directory is present at path. Only a
// "does not exist" error from os.Stat yields false. Any other Stat failure
// (for example, permission denied) is treated as existing.
func Exists(path string) bool {
	_, err := os.Stat(path)
	return !os.IsNotExist(err)
}
// CreateDirectory makes a single directory at path with fileModeExecutable
// permissions; parent directories are not created. If Exists(path) is already
// true, nothing happens and nil is returned.
func CreateDirectory(path string) error {
	if !Exists(path) {
		if err := os.Mkdir(path, fileModeExecutable); err != nil {
			return errors.Wrapf(err, "creating directory %s", path)
		}
	}
	return nil
}
// CreateFile creates filename as an empty file (truncating any existing file,
// as os.Create does) and closes it immediately. A failure either to create or
// to close is wrapped with the file name.
func CreateFile(filename string) error {
	f, err := os.Create(filename)
	if err != nil {
		return errors.Wrapf(err, "creating file %s", filename)
	}
	if err := f.Close(); err != nil {
		return errors.Wrapf(err, "creating file %s", filename)
	}
	return nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package log
import (
"fmt"
"io"
"os"
"golang.org/x/term"
)
// Std is the standard logger instance that writes to stdout.
// It is created once at package initialization via New(os.Stdout).
var Std = New(os.Stdout)
// Logger provides color-formatted logging with support for different log levels and TTY detection.
// Build it with New; the zero value has no writer and no color functions.
type Logger struct {
// colors maps each Level to the function that renders that level's text.
colors map[Level]colorFunc
// writer receives every log line.
writer io.Writer
// isTTY decides, at construction time, whether colorFunc wraps text in ANSI codes.
isTTY bool
}
// ANSI escape sequences for bold foreground colors, plus the reset sequence
// that restores default formatting.
const (
	ColorBlack   = "\x1b[1;30m" // bold black
	ColorRed     = "\x1b[1;31m" // bold red
	ColorGreen   = "\x1b[1;32m" // bold green
	ColorYellow  = "\x1b[1;33m" // bold yellow
	ColorBlue    = "\x1b[1;34m" // bold blue
	ColorMagenta = "\x1b[1;35m" // bold magenta
	ColorCyan    = "\x1b[1;36m" // bold cyan
	ColorWhite   = "\x1b[1;37m" // bold white
	ColorReset   = "\x1b[0m"    // reset color and formatting
)
// Level names the category of a log message; Logger uses it to pick a color.
type Level string

// Log levels understood by Logger.
const (
	Info    Level = "info"    // informational messages
	Success Level = "success" // successful operations
	Warn    Level = "warn"    // warnings
	Error   Level = "error"   // errors
)
// colorFunc renders its operands as a single, optionally colored, string.
type colorFunc func(args ...any) string
// New creates a Logger that writes to w.
//
// Colored (ANSI) output is enabled only when w itself is a terminal. That
// means w must expose a file descriptor through an Fd() uintptr method, as
// *os.File does, and golang.org/x/term must report that descriptor as a
// terminal. Any other writer (a bytes.Buffer, a pipe, a file) receives plain
// text. The previous check always inspected os.Stdout, regardless of w.
func New(w io.Writer) *Logger {
	isTTY := false
	if f, ok := w.(interface{ Fd() uintptr }); ok {
		isTTY = term.IsTerminal(int(f.Fd()))
	}
	c := &Logger{
		writer: w,
		isTTY:  isTTY, // must be set before colorFunc is called below
		colors: make(map[Level]colorFunc, 4),
	}
	c.colors[Info] = c.colorFunc(ColorWhite)
	c.colors[Success] = c.colorFunc(ColorGreen)
	c.colors[Warn] = c.colorFunc(ColorYellow)
	c.colors[Error] = c.colorFunc(ColorRed)
	return c
}
// writef formats format/args with fmt.Sprintf, colors the text for level and
// writes it to c.writer without appending a newline. Write errors are
// intentionally ignored.
func (c *Logger) writef(level Level, format string, args ...any) {
	paint := c.colors[level]
	_, _ = fmt.Fprint(c.writer, paint(fmt.Sprintf(format, args...)))
}

// writeln colors the operands for level and writes them to c.writer followed
// by a newline. Write errors are intentionally ignored.
func (c *Logger) writeln(level Level, a ...any) {
	paint := c.colors[level]
	_, _ = fmt.Fprintln(c.writer, paint(a...))
}

// Infof logs a formatted informational message; no newline is appended.
func (c *Logger) Infof(format string, args ...any) {
	c.writef(Info, format, args...)
}

// Info logs an informational message followed by a newline.
func (c *Logger) Info(a ...any) {
	c.writeln(Info, a...)
}

// Successf logs a formatted success message; no newline is appended.
func (c *Logger) Successf(format string, args ...any) {
	c.writef(Success, format, args...)
}

// Success logs a success message followed by a newline.
func (c *Logger) Success(a ...any) {
	c.writeln(Success, a...)
}

// Warnf logs a formatted warning message; no newline is appended.
func (c *Logger) Warnf(format string, args ...any) {
	c.writef(Warn, format, args...)
}

// Warn logs a warning message followed by a newline.
func (c *Logger) Warn(a ...any) {
	c.writeln(Warn, a...)
}

// Error logs an error message followed by a newline.
func (c *Logger) Error(a ...any) {
	c.writeln(Error, a...)
}

// Errorf logs a formatted error message; no newline is appended.
func (c *Logger) Errorf(format string, args ...any) {
	c.writef(Error, format, args...)
}

// Fatal logs an error message followed by a newline, then terminates the
// process with exit code 1.
func (c *Logger) Fatal(a ...any) {
	c.writeln(Error, a...)
	os.Exit(1)
}

// Fatalf logs a formatted error message (no newline appended), then
// terminates the process with exit code 1.
func (c *Logger) Fatalf(format string, args ...any) {
	c.writef(Error, format, args...)
	os.Exit(1)
}
// colorFunc returns the renderer for a level whose ANSI sequence is code.
//
// On a non-terminal logger it is plain fmt.Sprint. Otherwise the Sprint output
// is wrapped as code + text + ColorReset. Plain concatenation is used instead
// of building a Sprintf format string from code, which would misinterpret any
// '%' contained in code.
func (c *Logger) colorFunc(code string) colorFunc {
	if !c.isTTY {
		return fmt.Sprint
	}
	return func(a ...any) string {
		return code + fmt.Sprint(a...) + ColorReset
	}
}
package log
// NopLogger has the same logging methods as Logger but discards every message.
// Unlike Logger, its Fatal and Fatalf do not terminate the process.
type NopLogger struct{}

// Info discards its operands.
func (*NopLogger) Info(...any) {}

// Infof discards the format and its operands.
func (*NopLogger) Infof(string, ...any) {}

// Success discards its operands.
func (*NopLogger) Success(...any) {}

// Successf discards the format and its operands.
func (*NopLogger) Successf(string, ...any) {}

// Warn discards its operands.
func (*NopLogger) Warn(...any) {}

// Warnf discards the format and its operands.
func (*NopLogger) Warnf(string, ...any) {}

// Error discards its operands.
func (*NopLogger) Error(...any) {}

// Errorf discards the format and its operands.
func (*NopLogger) Errorf(string, ...any) {}

// Fatal discards its operands; it does not exit.
func (*NopLogger) Fatal(...any) {}

// Fatalf discards the format and its operands; it does not exit.
func (*NopLogger) Fatalf(string, ...any) {}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package sqlex
import (
"context"
"database/sql"
)
// DB wraps the standard database/sql.DB and provides extended functionality
// for query execution and transaction management with custom return types.
type DB struct {
// Embedded so that every *sql.DB method not redefined on DB stays callable.
*sql.DB
}
// QueryContext runs a row-returning query (typically a SELECT) with args bound
// to its placeholders and wraps the *sql.Rows as a Rows.
//
//nolint:ireturn,nolintlint // its ok
func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (Rows, error) {
	sqlRows, err := db.DB.QueryContext(ctx, query, args...)
	if err == nil {
		return NewRowsWithSQLRows(sqlRows), nil
	}
	return nil, err
}
// ExecContext runs a statement that returns no rows, binding args to its
// placeholders, and returns the driver's result.
//
//nolint:ireturn,nolintlint // its ok
func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
	res, err := db.DB.ExecContext(ctx, query, args...)
	return res, err
}
// BeginTx opens a transaction on the underlying *sql.DB and wraps it as a Tx.
//
// ctx governs the transaction until it is committed or rolled back: if it is
// canceled, database/sql rolls the transaction back and Commit fails. opts may
// be nil to use driver defaults; an isolation level the driver does not
// support produces an error.
//
//nolint:ireturn,nolintlint // its ok
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
	sqlTx, err := db.DB.BeginTx(ctx, opts)
	if err == nil {
		return NewTx(sqlTx), nil
	}
	return nil, err
}
// Close closes the embedded *sql.DB and returns its error.
func (db *DB) Close() error {
	inner := db.DB
	return inner.Close()
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package sqlex
import (
"database/sql"
"fmt"
"io"
"reflect"
"github.com/pkg/errors"
)
// Errors reported by RowsWithSlice.Scan for unusable destination arguments.
var (
	// ErrPtrValueMustBeAPointer is returned when a destination is not a pointer.
	ErrPtrValueMustBeAPointer = errors.New("ptr value must be a pointer")
	// ErrNilPtrValue is returned when a destination is nil.
	ErrNilPtrValue = errors.New("nil ptr value")
)
// Rows is a cursor over query results, mirroring the subset of *sql.Rows that
// callers use: iterate with Next, read with Scan, then check Err and Close.
type Rows interface {
	// Next reports whether a row is available for Scan.
	Next() bool
	// Scan copies the current row's columns into dest.
	Scan(dest ...any) error
	// Err returns any error encountered during iteration.
	Err() error
	// Close releases the cursor.
	Close() error
}
// RowsWithSlice is a Rows backed by an in-memory slice. Each element is one
// row: a []any holds several columns, any other value is a single column. It
// serves drivers whose results arrive already decoded, as well as tests.
type RowsWithSlice struct {
	i    int   // index of the next row to scan
	rows []any // the rows themselves
}
// RowsWithSQLRows wraps the standard database/sql.Rows to implement the custom Rows interface.
type RowsWithSQLRows struct {
// Embedded: Next, Scan, Err and Close are promoted from *sql.Rows.
*sql.Rows
}
// NewRowsWithSQLRows wraps rows so that it satisfies Rows.
func NewRowsWithSQLRows(rows *sql.Rows) *RowsWithSQLRows {
	wrapped := new(RowsWithSQLRows)
	wrapped.Rows = rows
	return wrapped
}
// NewRowsWithSlice returns a cursor positioned before the first element of
// rows. The slice is used directly, not copied.
func NewRowsWithSlice(rows []any) *RowsWithSlice {
	return &RowsWithSlice{i: 0, rows: rows}
}
// Next reports whether an unread row remains. It does not move the cursor;
// only a successful Scan advances it, so each Next must be followed by Scan
// or a `for rows.Next()` loop will not terminate.
func (v *RowsWithSlice) Next() bool {
	remaining := len(v.rows) - v.i
	return remaining > 0
}
// Scan copies the current row into dest and advances the cursor.
//
// A row stored as []any supplies one value per column; any other value is a
// single-column row. len(dest) must equal the column count, and every dest
// element must be a non-nil pointer. For each column:
//   - a nil value leaves the destination untouched;
//   - a value assignable to the destination's element type is stored as-is;
//   - otherwise a pointer value is dereferenced (a nil pointer leaves the
//     destination untouched);
//   - then the value is stored if assignable, or converted with reflect
//     (e.g. int8 -> int64) if convertible;
//   - anything else is an error.
//
// Scan returns a wrapped io.EOF when no rows remain. On error the cursor is
// not advanced; destinations filled before the failing column keep their new
// values.
func (v *RowsWithSlice) Scan(dest ...any) error {
	if v.i >= len(v.rows) {
		return errors.WithStack(io.EOF)
	}
	row := v.rows[v.i]
	rowVals, ok := row.([]any)
	if !ok {
		rowVals = []any{row}
	}
	// Check the column count once, up front, so a mismatch (including a call
	// with no destinations at all) never writes anything or advances the cursor.
	if len(dest) != len(rowVals) {
		return errors.WithStack(
			fmt.Errorf("sql: expected %d destination arguments in Scan, not %d",
				len(rowVals),
				len(dest),
			))
	}
	for k, destValPtr := range dest {
		if destValPtr == nil {
			return errors.WithStack(ErrNilPtrValue)
		}
		destVal := reflect.ValueOf(destValPtr)
		if destVal.Kind() != reflect.Pointer {
			return errors.WithStack(ErrPtrValueMustBeAPointer)
		}
		if destVal.IsNil() {
			// A typed nil pointer has no element; Elem().Set would panic.
			return errors.WithStack(ErrNilPtrValue)
		}
		target := destVal.Elem()
		targetType := target.Type()

		rowVal := rowVals[k]
		if rowVal == nil {
			continue
		}
		src := reflect.ValueOf(rowVal)
		// Dereference pointer values unless the pointer itself fits the target
		// (e.g. scanning *int into **int). Setting a pointer into a non-pointer
		// target used to panic.
		if src.Kind() == reflect.Pointer && !src.Type().AssignableTo(targetType) {
			if src.IsNil() {
				continue
			}
			src = src.Elem()
		}
		switch {
		case src.Type().AssignableTo(targetType):
			target.Set(src)
		case src.CanConvert(targetType):
			// Also covers same-kind but distinct named types (e.g. int64 into
			// time.Duration), which a plain Set would reject with a panic.
			target.Set(src.Convert(targetType))
		default:
			return errors.WithStack(
				errors.Errorf("variable of type %T cannot be set to variable of type %T",
					rowVal,
					destValPtr,
				))
		}
	}
	v.i++
	return nil
}
// Err always reports no error: walking an in-memory slice cannot fail.
func (v *RowsWithSlice) Err() error {
	var noErr error
	return noErr
}
// Close is a no-op for the in-memory cursor and always returns nil. It does
// not prevent further calls to Next or Scan.
func (v *RowsWithSlice) Close() error {
	var noErr error
	return noErr
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package sqlex
import (
"context"
"database/sql"
)
// Stmt is a prepared statement that can be executed repeatedly with different
// bound arguments.
type Stmt interface {
	// ExecContext runs the statement with args bound to its placeholders.
	ExecContext(ctx context.Context, args ...any) (Result, error)
}
// stmt wraps the standard database/sql.Stmt to implement the custom Stmt interface.
type stmt struct {
// Embedded *sql.Stmt; stmt.ExecContext shadows its method to return Result.
*sql.Stmt
}
// ExecContext runs the prepared statement with args and returns the driver's
// result as a Result.
//
//nolint:ireturn,nolintlint // its ok
func (s *stmt) ExecContext(ctx context.Context, args ...any) (Result, error) {
	res, err := s.Stmt.ExecContext(ctx, args...)
	return res, err
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package tarantool
import (
"context"
"database/sql"
"time"
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/helper/dsn"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/sqlex"
"github.com/tarantool/go-tarantool/v2"
)
// defaultQueryTimeout is the default timeout used by Open for the Tarantool
// connection.
const defaultQueryTimeout = time.Second * 10
//go:generate mockery
// Doer submits a single Tarantool request. Both the connection and a
// *tarantool.Stream provide Do, so stream operations can be mocked through
// this interface.
type Doer interface {
	// Do sends req and returns a future for its response.
	Do(req tarantool.Request) *tarantool.Future
}
// Connection is the part of the go-tarantool connection API that DB relies
// on: sending requests (Doer), opening streams for transactions, and closing.
type Connection interface {
	Doer
	// NewStream opens a stream on which a transaction can be run.
	NewStream() (*tarantool.Stream, error)
	// Close shuts the connection down.
	Close() error
}
// DB wraps a Tarantool connection and implements the sqlex database interface.
// It provides query execution, transaction management, and connection lifecycle methods.
type DB struct {
// conn carries every request; Close tolerates it being nil.
conn Connection
}
// Open parses dsnStr (format: tarantool://username:password@host:port) and
// connects to its primary address, returning a DB that owns the connection.
//
// Requests on the connection time out after defaultQueryTimeout. The initial
// connection attempt is bounded by the same duration. Previously it used
// context.Background(), so an unreachable host could block Open for as long
// as the underlying dial did.
func Open(dsnStr string) (*DB, error) {
	parsed, err := dsn.Parse(dsnStr)
	if err != nil {
		return nil, errors.WithMessage(err, "parsing tarantool DSN")
	}
	addr := parsed.PrimaryAddress()
	dialer := tarantool.NetDialer{
		Address:  addr,
		User:     parsed.Username,
		Password: parsed.Password,
	}
	opts := tarantool.Opts{
		Timeout: defaultQueryTimeout, // todo: extract from query param dsn
	}
	connectCtx, cancel := context.WithTimeout(context.Background(), defaultQueryTimeout)
	defer cancel()
	conn, err := tarantool.Connect(connectCtx, dialer, opts)
	if err != nil {
		return nil, errors.Wrapf(err, "connect to tarantool at %s", addr)
	}
	return &DB{conn: conn}, nil
}
// Ping sends a ping request and waits for its response, returning any error.
func (db *DB) Ping() error {
	req := tarantool.NewPingRequest()
	_, err := db.conn.Do(req).GetResponse()
	return err
}
// QueryContext evaluates the Lua chunk query on the server, passing args as
// the chunk's varargs, and exposes the returned values as rows.
//
// Each returned value becomes one row. As a special case, when the chunk
// returns exactly one value and that value is an array, its elements become
// the rows. A table of tuples therefore yields one row per tuple. Note that a
// single bare tuple is also unwrapped, into one row per field, so chunks that
// return one tuple should wrap it in a table.
//
//nolint:ireturn,nolintlint // its ok
func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (sqlex.Rows, error) {
	req := tarantool.NewEvalRequest(query).Context(ctx)
	if len(args) != 0 {
		req = req.Args(args)
	}
	data, err := db.conn.Do(req).Get()
	if err != nil {
		return nil, err
	}
	if len(data) == 1 {
		if inner, isSlice := data[0].([]any); isSlice {
			data = inner
		}
	}
	return sqlex.NewRowsWithSlice(data), nil
}
// ExecContext evaluates the Lua chunk query with args as its varargs. It
// discards whatever the chunk returns and reports success as Done(true).
//
//nolint:ireturn,nolintlint // its ok
func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (sqlex.Result, error) {
	req := tarantool.NewEvalRequest(query).Context(ctx)
	if len(args) != 0 {
		req = req.Args(args)
	}
	if _, err := db.conn.Do(req).Get(); err != nil {
		return nil, err
	}
	return Done(true), nil
}
// BeginTx opens a new stream on the connection and sends a synchronous begin
// request on it (with ctx attached), returning a transaction bound to that stream.
// The *sql.TxOptions argument is ignored.
//
//nolint:ireturn,nolintlint // its ok
func (db *DB) BeginTx(ctx context.Context, _ *sql.TxOptions) (sqlex.Tx, error) {
	stream, err := db.conn.NewStream()
	if err != nil {
		return nil, err
	}

	begin := tarantool.NewBeginRequest().Context(ctx).IsSync(true)
	if _, err = stream.Do(begin).Get(); err != nil {
		return nil, err
	}

	return NewTx(stream), nil
}
// Close closes the underlying connection and returns its error. A DB without
// a connection is treated as already closed and yields nil.
func (db *DB) Close() error {
	if db.conn == nil {
		return nil
	}

	return db.conn.Close()
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package tarantool
import (
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/sqlex"
)
// ErrIsNotSupportedByThisDriver indicates that the requested operation is not supported by the Tarantool driver.
var ErrIsNotSupportedByThisDriver = errors.New("is not supported by this driver")

// Done implements sqlex.Result for Tarantool requests that finished successfully.
// The driver's ExecContext methods return Done(true) after a request succeeds,
// whatever kind of request (Lua eval or prepared statement) it was. The driver
// does not extract a last-insert id or an affected-row count from Tarantool
// responses, so both accessors always fail with ErrIsNotSupportedByThisDriver.
type Done bool

var _ sqlex.Result = Done(true)

// LastInsertId always returns 0 and an error wrapping ErrIsNotSupportedByThisDriver.
func (Done) LastInsertId() (int64, error) {
	return 0, errors.WithMessage(ErrIsNotSupportedByThisDriver, "LastInsertId")
}

// RowsAffected always returns 0 and an error wrapping ErrIsNotSupportedByThisDriver.
func (Done) RowsAffected() (int64, error) {
	return 0, errors.WithMessage(ErrIsNotSupportedByThisDriver, "RowsAffected")
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package tarantool
import (
"context"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/sqlex"
"github.com/tarantool/go-tarantool/v2"
)
// Stmt exposes a Tarantool prepared statement through the sqlex.Stmt interface.
// Executions are sent over the connection recorded in the prepared statement
// (stmt.Conn), not over any stream.
type Stmt struct {
	stmt *tarantool.Prepared
}

// ExecContext runs the prepared statement with args, with ctx attached to the
// request. The response body is discarded; on success it returns Done(true).
//
//nolint:ireturn,nolintlint // its ok
func (s *Stmt) ExecContext(ctx context.Context, args ...any) (sqlex.Result, error) {
	exec := tarantool.NewExecutePreparedRequest(s.stmt).Context(ctx)
	if len(args) != 0 {
		exec = exec.Args(args)
	}

	_, err := s.stmt.Conn.Do(exec).Get()
	if err != nil {
		return nil, err
	}

	return Done(true), nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package tarantool
import (
"context"
"sync"
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/sqlex"
"github.com/tarantool/go-tarantool/v2"
)
// ErrTransactionAlreadyClosed is returned by a transaction's methods once
// Commit or Rollback has already been attempted on it.
var ErrTransactionAlreadyClosed = errors.New("transaction already closed")
// StreamDoer defines the interface for stream operations used by transactions.
// It is implemented by *tarantool.Stream and can be mocked for testing.
type StreamDoer interface {
Doer
// Conn returns the underlying connection for prepared statement operations.
// Note: This returns a concrete type from the tarantool library.
Conn() *tarantool.Connection
}
// streamWrapper adapts *tarantool.Stream to StreamDoer. Do is promoted from
// the embedded stream, and Conn turns the stream's Conn field into a method.
type streamWrapper struct {
	*tarantool.Stream
}

// Conn reports the connection that the wrapped stream belongs to.
func (w *streamWrapper) Conn() *tarantool.Connection {
	stream := w.Stream

	return stream.Conn
}
// tx wraps a Tarantool stream and implements the sqlex.Tx interface.
// It provides transaction operations including commit, rollback, and query execution.
// The begin request is sent by DB.BeginTx before the stream is wrapped.
type tx struct {
	// stream carries every request belonging to this transaction.
	stream StreamDoer
	// closed becomes true once Commit or Rollback has been attempted, even if
	// that attempt failed; later calls return ErrTransactionAlreadyClosed.
	closed bool
	// mu guards closed: Commit and Rollback take the write lock, while
	// ExecContext and PrepareContext take the read lock.
	mu sync.RWMutex
}
// NewTx wraps stream so it satisfies StreamDoer and returns a transaction that
// sends all of its requests over that stream. The stream is expected to already
// have a transaction begun on it.
//
//nolint:ireturn,nolintlint // its ok
func NewTx(stream *tarantool.Stream) sqlex.Tx {
	wrapped := &streamWrapper{Stream: stream}

	return &tx{stream: wrapped}
}
// newTxWithStreamDoer builds a transaction directly on a StreamDoer, which
// lets tests drive a tx with a mock stream instead of a real *tarantool.Stream.
//
//nolint:unused // used in tests
func newTxWithStreamDoer(stream StreamDoer) *tx {
	t := new(tx)
	t.stream = stream

	return t
}
// Commit sends a commit request on the transaction's stream.
//
// The transaction is marked closed whether or not the commit succeeds, and the
// commit error (if any) is returned. Calling Commit on a closed transaction
// returns ErrTransactionAlreadyClosed without contacting the server.
func (tx *tx) Commit() error {
	tx.mu.Lock()
	defer tx.mu.Unlock()

	if tx.closed {
		return errors.WithStack(ErrTransactionAlreadyClosed)
	}

	commit := tarantool.NewCommitRequest()
	_, commitErr := tx.stream.Do(commit).Get()
	tx.closed = true

	return commitErr
}
// Rollback sends a rollback request on the transaction's stream.
//
// The transaction is marked closed whether or not the rollback succeeds, and
// the rollback error (if any) is returned. Calling Rollback on a closed
// transaction returns ErrTransactionAlreadyClosed without contacting the server.
func (tx *tx) Rollback() error {
	tx.mu.Lock()
	defer tx.mu.Unlock()

	if tx.closed {
		return errors.WithStack(ErrTransactionAlreadyClosed)
	}

	rollback := tarantool.NewRollbackRequest()
	_, rollbackErr := tx.stream.Do(rollback).Get()
	tx.closed = true

	return rollbackErr
}
// ExecContext evaluates query as Lua code on the transaction's stream and
// discards the returned values.
//
// args, when present, are forwarded to the eval request, and ctx is attached to it.
// On success the result is Done(true). If the transaction is already closed,
// ErrTransactionAlreadyClosed is returned and nothing is sent.
//
//nolint:ireturn,nolintlint // its ok
func (tx *tx) ExecContext(ctx context.Context, query string, args ...any) (sqlex.Result, error) {
	tx.mu.RLock()
	defer tx.mu.RUnlock()

	if tx.closed {
		return nil, errors.WithStack(ErrTransactionAlreadyClosed)
	}

	eval := tarantool.NewEvalRequest(query).Context(ctx)
	if len(args) != 0 {
		eval = eval.Args(args)
	}

	_, err := tx.stream.Do(eval).Get()
	if err != nil {
		return nil, err
	}

	return Done(true), nil
}
// PrepareContext prepares query on the transaction's stream and returns a
// statement bound to this transaction.
//
// Executions of the returned statement are sent through the same stream, so
// they take part in the transaction and are undone by Rollback. Previously a
// plain *Stmt was returned, and it executes through the prepared statement's
// connection, i.e. outside the stream's transaction.
// It returns ErrTransactionAlreadyClosed if the transaction has already been closed.
//
//nolint:ireturn,nolintlint // its ok
func (tx *tx) PrepareContext(ctx context.Context, query string) (sqlex.Stmt, error) {
	tx.mu.RLock()
	defer tx.mu.RUnlock()
	if tx.closed {
		return nil, errors.WithStack(ErrTransactionAlreadyClosed)
	}
	resp, err := tx.stream.Do(tarantool.NewPrepareRequest(query).Context(ctx)).GetResponse()
	if err != nil {
		return nil, err
	}
	// NewPreparedFromResponse needs the owning connection; execution still goes
	// through the stream (see txStmt.ExecContext).
	prepared, err := tarantool.NewPreparedFromResponse(tx.stream.Conn(), resp)
	if err != nil {
		return nil, err
	}
	return &txStmt{owner: tx, stmt: prepared}, nil
}

// txStmt is a prepared statement produced by tx.PrepareContext. Unlike Stmt,
// it sends its executions over the owning transaction's stream.
type txStmt struct {
	// owner is the transaction whose stream carries the executions.
	owner *tx
	// stmt is the server-side prepared statement.
	stmt *tarantool.Prepared
}

// ExecContext runs the prepared statement with args inside the owning
// transaction, with ctx attached to the request. On success it returns
// Done(true). If the owning transaction is already closed it returns
// ErrTransactionAlreadyClosed rather than silently executing outside it.
//
//nolint:ireturn,nolintlint // its ok
func (s *txStmt) ExecContext(ctx context.Context, args ...any) (sqlex.Result, error) {
	// Hold the read lock so Commit/Rollback cannot close the tx mid-execution.
	s.owner.mu.RLock()
	defer s.owner.mu.RUnlock()
	if s.owner.closed {
		return nil, errors.WithStack(ErrTransactionAlreadyClosed)
	}
	req := tarantool.NewExecutePreparedRequest(s.stmt).Context(ctx)
	if len(args) > 0 {
		req = req.Args(args)
	}
	if _, err := s.owner.stream.Do(req).Get(); err != nil {
		return nil, err
	}
	return Done(true), nil
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package sqlex
import (
"context"
"database/sql"
)
// Tx is a database transaction handle. Implementations run and prepare
// statements within the transaction and finish it with Commit or Rollback.
type Tx interface {
	ExecContext(ctx context.Context, query string, args ...any) (Result, error)
	PrepareContext(ctx context.Context, query string) (Stmt, error)
	Commit() error
	Rollback() error
}
// sqlTx wraps the standard database/sql.Tx to implement the custom Tx interface.
// Commit and Rollback are promoted from the embedded *sql.Tx; PrepareContext
// and ExecContext are redefined so their results use this package's Stmt and
// Result types.
type sqlTx struct {
	*sql.Tx
}
// NewTx adapts tx to the Tx interface. The *sql.Tx is stored as-is and is not
// checked for nil.
//
//nolint:ireturn,nolintlint // its ok
func NewTx(tx *sql.Tx) Tx {
	wrapper := sqlTx{Tx: tx}

	return &wrapper
}
// PrepareContext prepares query on the wrapped *sql.Tx and returns it as a Stmt.
// Any error from database/sql is returned unchanged.
//
//nolint:ireturn,nolintlint // its ok
func (s *sqlTx) PrepareContext(ctx context.Context, query string) (Stmt, error) {
	st, err := s.Tx.PrepareContext(ctx, query)
	if err == nil {
		return &stmt{Stmt: st}, nil
	}

	return nil, err
}
// ExecContext runs query with args on the wrapped *sql.Tx and passes its
// result and error through untouched.
//
//nolint:ireturn,nolintlint // its ok
func (s *sqlTx) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
	res, err := s.Tx.ExecContext(ctx, query, args...)

	return res, err
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package thelp
import (
"regexp"
"strings"
)
var (
	// patternWhitespace matches any run of whitespace characters
	// (RE2's \s already covers \t, \n, \f, \r and space).
	patternWhitespace = regexp.MustCompile(`\s+`)
	// patternSpaceCommaSpace matches a comma that is followed by at least one
	// whitespace character, together with any whitespace before it.
	patternSpaceCommaSpace = regexp.MustCompile(`\s*,\s+`)
)

// CompareSQL returns a predicate reporting whether an actual SQL string equals
// expected once both have been normalized with CleanSQL.
//
// expected is normalized once, when CompareSQL is called, instead of being
// re-cleaned and reassigned on every invocation. The returned function
// therefore mutates no shared state and may be called from several goroutines.
func CompareSQL(expected string) func(actual string) bool {
	want := CleanSQL(expected)

	return func(actual string) bool {
		return CleanSQL(actual) == want
	}
}

// CleanSQL normalizes SQL text so that statements differing only in formatting
// compare equal. It:
//   - trims leading and trailing whitespace,
//   - collapses every whitespace run to a single space,
//   - removes whitespace around a comma when whitespace follows the comma,
//   - drops the space after "(" and before ")".
//
// For example, "INSERT INTO t ( a ,  b )" becomes "INSERT INTO t (a,b)".
func CleanSQL(sql string) string {
	sql = strings.TrimSpace(sql)
	sql = patternWhitespace.ReplaceAllString(sql, " ")
	sql = patternSpaceCommaSpace.ReplaceAllString(sql, ",")
	sql = strings.ReplaceAll(sql, " )", ")")
	sql = strings.ReplaceAll(sql, "( ", "(")

	return sql
}
/**
* This file is part of the raoptimus/db-migrator.go library
*
* @copyright Copyright (c) Evgeniy Urvantsev
* @license https://github.com/raoptimus/db-migrator.go/blob/master/LICENSE.md
* @link https://github.com/raoptimus/db-migrator.go
*/
package dbmigrator
import (
"context"
"github.com/pkg/errors"
"github.com/raoptimus/db-migrator.go/internal/application/handler"
"github.com/raoptimus/db-migrator.go/internal/domain/validator"
"github.com/raoptimus/db-migrator.go/internal/infrastructure/log"
)
type (
	// Options configures the database migration service.
	// It contains connection settings and database-specific parameters.
	Options struct {
		// DSN is the database connection string; NewDBService copies it into handler.Options.
		DSN string
		// TableName is the name of the table that stores the migration history.
		TableName string
		// ClusterName is the ClickHouse cluster name.
		ClusterName string
		// Replicated reports whether ClickHouse replicated tables are used.
		Replicated bool
	}
	// DBService provides high-level operations for database migrations.
	// It orchestrates the migration process using the underlying migration service.
	DBService struct {
		// opts are the handler options built and validated by NewDBService.
		opts *handler.Options
		// logger receives service output; NewDBService substitutes a no-op logger for nil.
		logger Logger
		// conn is the database connection handed to each migration service.
		conn Connection
	}
)
// NewDBService creates a new database migration service with the provided options, connection, and logger.
//
// opts must not be nil: a nil opts is reported as an error instead of causing a
// nil-pointer panic. If logger is nil, a no-op logger will be used. The options
// are copied into handler.Options with MaxConnAttempts set to 1, an empty
// Directory, Compact and Interactive enabled, and MaxSQLOutputLength left at 0.
// The result is then validated, and a validation error is returned unchanged.
// conn is stored as-is and is not checked here.
func NewDBService(opts *Options, conn Connection, logger Logger) (*DBService, error) {
	if opts == nil {
		return nil, errors.New("options must not be nil")
	}
	if logger == nil {
		logger = &log.NopLogger{}
	}
	options := &handler.Options{
		DSN:                opts.DSN,
		MaxConnAttempts:    1,
		Directory:          "",
		TableName:          opts.TableName,
		ClusterName:        opts.ClusterName,
		Replicated:         opts.Replicated,
		Compact:            true,
		Interactive:        true,
		MaxSQLOutputLength: 0,
	}
	if err := options.Validate(); err != nil {
		return nil, err
	}
	return &DBService{
		opts:   options,
		conn:   conn,
		logger: logger,
	}, nil
}
// Upgrade applies sql to the database as migration version.
//
// The version string is validated first. If the migration service reports that
// version already exists, ErrMigrationAlreadyExists is returned and nothing is
// applied. Otherwise the SQL is applied via ApplySQL, with safety passed through.
func (d *DBService) Upgrade(ctx context.Context, version, sql string, safety bool) error {
	if err := validator.ValidateVersion(version); err != nil {
		return err
	}

	migrations, err := handler.NewMigrationService(d.opts, d.logger, d.conn)
	if err != nil {
		return err
	}

	exists, err := migrations.Exists(ctx, version)
	switch {
	case err != nil:
		return err
	case exists:
		return errors.WithStack(ErrMigrationAlreadyExists)
	}

	return migrations.ApplySQL(ctx, safety, version, sql)
}
// Downgrade reverts migration version from the database by running sql.
//
// The version string is validated first. If the migration service reports that
// version does not exist, ErrAppliedMigrationNotFound is returned and nothing
// is reverted. Otherwise the SQL is run via RevertSQL, with safety passed through.
func (d *DBService) Downgrade(ctx context.Context, version, sql string, safety bool) error {
	if err := validator.ValidateVersion(version); err != nil {
		return err
	}

	migrations, err := handler.NewMigrationService(d.opts, d.logger, d.conn)
	if err != nil {
		return err
	}

	exists, err := migrations.Exists(ctx, version)
	switch {
	case err != nil:
		return err
	case !exists:
		return errors.WithStack(ErrAppliedMigrationNotFound)
	}

	return migrations.RevertSQL(ctx, safety, version, sql)
}