// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package dbw import ( "math" "math/rand" "time" ) // Backoff defines an interface for providing a back off for retrying // transactions. See DoTx(...) type Backoff interface { Duration(attemptNumber uint) time.Duration } // ConstBackoff defines a constant backoff for retrying transactions. See // DoTx(...) type ConstBackoff struct { DurationMs time.Duration } // Duration is the constant backoff duration based on the retry attempt func (b ConstBackoff) Duration(attempt uint) time.Duration { return time.Millisecond * time.Duration(b.DurationMs) } // ExpBackoff defines an exponential backoff for retrying transactions. See DoTx(...) type ExpBackoff struct { testRand float64 } // Duration is the exponential backoff duration based on the retry attempt func (b ExpBackoff) Duration(attempt uint) time.Duration { var r float64 switch { case b.testRand > 0: r = b.testRand default: r = rand.Float64() } return time.Millisecond * time.Duration(math.Exp2(float64(attempt))*5*(r+0.5)) }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package dbw import ( "sort" "gorm.io/gorm" "gorm.io/gorm/clause" ) // ColumnValue defines a column and it's assigned value for a database // operation. See: SetColumnValues(...) type ColumnValue struct { // Column name Column string // Value is the column's value Value interface{} } // Column represents a table Column type Column struct { // Name of the column Name string // Table name of the column Table string } func (c *Column) toAssignment(column string) clause.Assignment { return clause.Assignment{ Column: clause.Column{Name: column}, Value: clause.Column{Table: c.Table, Name: c.Name}, } } func rawAssignment(column string, value interface{}) clause.Assignment { return clause.Assignment{ Column: clause.Column{Name: column}, Value: value, } } // ExprValue encapsulates an expression value for a column assignment. See // Expr(...) to create these values. type ExprValue struct { Sql string Vars []interface{} } func (ev *ExprValue) toAssignment(column string) clause.Assignment { return clause.Assignment{ Column: clause.Column{Name: column}, Value: gorm.Expr(ev.Sql, ev.Vars...), } } // Expr creates an expression value (ExprValue) which can be used when setting // column values for database operations. See: Expr(...) // // Set name column to null example: // // SetColumnValues(map[string]interface{}{"name": Expr("NULL")}) // // Set exp_time column to N seconds from now: // // SetColumnValues(map[string]interface{}{"exp_time": Expr("wt_add_seconds_to_now(?)", 10)}) func Expr(expr string, args ...interface{}) ExprValue { return ExprValue{Sql: expr, Vars: args} } // SetColumnValues defines a map from column names to values for database // operations. func SetColumnValues(columnValues map[string]interface{}) []ColumnValue { keys := make([]string, 0, len(columnValues)) for key := range columnValues { keys = append(keys, key) } sort.Strings(keys) assignments := make([]ColumnValue, len(keys)) for idx, key := range keys { assignments[idx] = ColumnValue{Column: key, Value: columnValues[key]} } return assignments } // SetColumns defines a list of column (names) to update using the set of // proposed insert columns during an on conflict update. func SetColumns(names []string) []ColumnValue { assignments := make([]ColumnValue, len(names)) for idx, name := range names { assignments[idx] = ColumnValue{ Column: name, Value: Column{Name: name, Table: "excluded"}, } } return assignments } // OnConflict specifies how to handle alternative actions to take when an insert // results in a unique constraint or exclusion constraint error. type OnConflict struct { // Target specifies what conflict you want to define a policy for. This can // be any one of these: // Columns: the name of a specific column or columns // Constraint: the name of a unique constraint Target interface{} // Action specifies the action to take on conflict. This can be any one of // these: // DoNothing: leaves the conflicting record as-is // UpdateAll: updates all the columns of the conflicting record using the resource's data // []ColumnValue: update a set of columns of the conflicting record using the set of assignments Action interface{} } // Constraint defines database constraint name type Constraint string // Columns defines a set of column names type Columns []string // DoNothing defines an "on conflict" action of doing nothing type DoNothing bool // UpdateAll defines an "on conflict" action of updating all columns using the // proposed insert column values type UpdateAll bool
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package dbw import ( "fmt" "reflect" "strings" "gorm.io/gorm" ) // UpdateFields will create a map[string]interface of the update values to be // sent to the db. The map keys will be the field names for the fields to be // updated. The caller provided fieldMaskPaths and setToNullPaths must not // intersect. fieldMaskPaths and setToNullPaths cannot both be zero len. func UpdateFields(i interface{}, fieldMaskPaths []string, setToNullPaths []string) (map[string]interface{}, error) { const op = "dbw.UpdateFields" if i == nil { return nil, fmt.Errorf("%s: interface is missing: %w", op, ErrInvalidParameter) } if fieldMaskPaths == nil { fieldMaskPaths = []string{} } if setToNullPaths == nil { setToNullPaths = []string{} } if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 { return nil, fmt.Errorf("%s: both fieldMaskPaths and setToNullPaths are zero len: %w", op, ErrInvalidParameter) } inter, maskPaths, nullPaths, err := Intersection(fieldMaskPaths, setToNullPaths) if err != nil { return nil, fmt.Errorf("%s: %w", op, ErrInvalidParameter) } if len(inter) != 0 { return nil, fmt.Errorf("%s: fieldMashPaths and setToNullPaths cannot intersect: %w", op, ErrInvalidParameter) } updateFields := map[string]interface{}{} // case sensitive update fields to values found := map[string]struct{}{} // we need something to keep track of found fields (case insensitive) val := reflect.Indirect(reflect.ValueOf(i)) structTyp := val.Type() for i := 0; i < structTyp.NumField(); i++ { if f, ok := maskPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok { updateFields[f] = val.Field(i).Interface() found[strings.ToUpper(f)] = struct{}{} continue } if f, ok := nullPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok { updateFields[f] = gorm.Expr("NULL") found[strings.ToUpper(f)] = struct{}{} continue } kind := structTyp.Field(i).Type.Kind() if kind == reflect.Struct || kind == reflect.Ptr { embType := structTyp.Field(i).Type // check if the embedded field is exported via CanInterface() if val.Field(i).CanInterface() { embVal := reflect.Indirect(reflect.ValueOf(val.Field(i).Interface())) // if it's a ptr to a struct, then we need a few more bits before proceeding. if kind == reflect.Ptr { embVal = val.Field(i).Elem() if !embVal.IsValid() { continue } embType = embVal.Type() if embType.Kind() != reflect.Struct { continue } } for embFieldNum := 0; embFieldNum < embType.NumField(); embFieldNum++ { if f, ok := maskPaths[strings.ToUpper(embType.Field(embFieldNum).Name)]; ok { updateFields[f] = embVal.Field(embFieldNum).Interface() found[strings.ToUpper(f)] = struct{}{} } if f, ok := nullPaths[strings.ToUpper(embType.Field(embFieldNum).Name)]; ok { updateFields[f] = gorm.Expr("NULL") found[strings.ToUpper(f)] = struct{}{} } } continue } } } if missing := findMissingPaths(setToNullPaths, found); len(missing) != 0 { return nil, fmt.Errorf("%s: null paths not found in resource: %s: %w", op, missing, ErrInvalidParameter) } if missing := findMissingPaths(fieldMaskPaths, found); len(missing) != 0 { return nil, fmt.Errorf("%s: field mask paths not found in resource: %s: %w", op, missing, ErrInvalidParameter) } return updateFields, nil } func findMissingPaths(paths []string, foundPaths map[string]struct{}) []string { notFound := []string{} for _, f := range paths { if _, ok := foundPaths[strings.ToUpper(f)]; !ok { notFound = append(notFound, f) } } return notFound } // Intersection is a case-insensitive search for intersecting values. Returns // []string of the Intersection with values in lowercase, and map[string]string // of the original av and bv, with the key set to uppercase and value set to the // original func Intersection(av, bv []string) ([]string, map[string]string, map[string]string, error) { const op = "dbw.Intersection" if av == nil { return nil, nil, nil, fmt.Errorf("%s: av is missing: %w", op, ErrInvalidParameter) } if bv == nil { return nil, nil, nil, fmt.Errorf("%s: bv is missing: %w", op, ErrInvalidParameter) } if len(av) == 0 && len(bv) == 0 { return []string{}, map[string]string{}, map[string]string{}, nil } s := []string{} ah := map[string]string{} bh := map[string]string{} for i := 0; i < len(av); i++ { ah[strings.ToUpper(av[i])] = av[i] } for i := 0; i < len(bv); i++ { k := strings.ToUpper(bv[i]) bh[k] = bv[i] if _, found := ah[k]; found { s = append(s, strings.ToLower(bh[k])) } } return s, ah, bh, nil } // BuildUpdatePaths takes a map of field names to field values, field masks, // fields allowed to be zero value, and returns both a list of field names to // update and a list of field names that should be set to null. func BuildUpdatePaths(fieldValues map[string]interface{}, fieldMask []string, allowZeroFields []string) (masks []string, nulls []string) { for f, v := range fieldValues { if !contains(fieldMask, f) { continue } switch { case isZero(v) && !contains(allowZeroFields, f): nulls = append(nulls, f) default: masks = append(masks, f) } } return masks, nulls } func isZero(i interface{}) bool { return i == nil || reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package dbw import ( "context" "fmt" "reflect" "strings" "sync/atomic" "gorm.io/gorm/clause" ) // OpType defines a set of database operation types type OpType int const ( // UnknownOp is an unknown operaton UnknownOp OpType = 0 // CreateOp is a create operation CreateOp OpType = 1 // UpdateOp is an update operation UpdateOp OpType = 2 // DeleteOp is a delete operation DeleteOp OpType = 3 // DefaultBatchSize is the default batch size for bulk operations like // CreateItems. This value is used if the caller does not specify a size // using the WithBatchSize(...) option. Note: some databases have a limit // on the number of query parameters (postgres is currently 64k and sqlite // is 32k) and/or size of a SQL statement (sqlite is currently 1bn bytes), // so this value should be set to a value that is less than the limits for // your target db. // See: // - https://www.postgresql.org/docs/current/limits.html // - https://www.sqlite.org/limits.html DefaultBatchSize = 1000 ) // VetForWriter provides an interface that Create and Update can use to vet the // resource before before writing it to the db. For optType == UpdateOp, // options WithFieldMaskPath and WithNullPaths are supported. For optType == // CreateOp, no options are supported type VetForWriter interface { VetForWrite(ctx context.Context, r Reader, opType OpType, opt ...Option) error } var nonCreateFields atomic.Value // InitNonCreatableFields sets the fields which are not setable using // via RW.Create(...) func InitNonCreatableFields(fields []string) { m := make(map[string]struct{}, len(fields)) for _, f := range fields { m[f] = struct{}{} } nonCreateFields.Store(m) } // NonCreatableFields returns the current set of fields which are not setable using // via RW.Create(...) func NonCreatableFields() []string { m := nonCreateFields.Load() if m == nil { return []string{} } fields := make([]string, 0, len(m.(map[string]struct{}))) for f := range m.(map[string]struct{}) { fields = append(fields, f) } return fields } // Create a resource in the db with options: WithDebug, WithLookup, // WithReturnRowsAffected, OnConflict, WithBeforeWrite, WithAfterWrite, // WithVersion, WithTable, and WithWhere. // // OnConflict specifies alternative actions to take when an insert results in a // unique constraint or exclusion constraint error. If WithVersion is used with // OnConflict, then the update for on conflict will include the version number, // which basically makes the update use optimistic locking and the update will // only succeed if the existing rows version matches the WithVersion option. // Zero is not a valid value for the WithVersion option and will return an // error. WithWhere allows specifying an additional constraint on the on // conflict operation in addition to the on conflict target policy (columns or // constraint). func (rw *RW) Create(ctx context.Context, i interface{}, opt ...Option) error { const op = "dbw.Create" if rw.underlying == nil { return fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) } if isNil(i) { return fmt.Errorf("%s: missing interface: %w", op, ErrInvalidParameter) } if err := raiseErrorOnHooks(i); err != nil { return fmt.Errorf("%s: %w", op, err) } opts := GetOpts(opt...) // these fields should be nil, since they are not writeable and we want the // db to manage them setFieldsToNil(i, NonCreatableFields()) if !opts.WithSkipVetForWrite { if vetter, ok := i.(VetForWriter); ok { if err := vetter.VetForWrite(ctx, rw, CreateOp); err != nil { return fmt.Errorf("%s: %w", op, err) } } } db := rw.underlying.wrapped.WithContext(ctx) if opts.WithOnConflict != nil { c := clause.OnConflict{} switch opts.WithOnConflict.Target.(type) { case Constraint: c.OnConstraint = string(opts.WithOnConflict.Target.(Constraint)) case Columns: columns := make([]clause.Column, 0, len(opts.WithOnConflict.Target.(Columns))) for _, name := range opts.WithOnConflict.Target.(Columns) { columns = append(columns, clause.Column{Name: name}) } c.Columns = columns default: return fmt.Errorf("%s: invalid conflict target %v: %w", op, reflect.TypeOf(opts.WithOnConflict.Target), ErrInvalidParameter) } switch opts.WithOnConflict.Action.(type) { case DoNothing: c.DoNothing = true case UpdateAll: c.UpdateAll = true case []ColumnValue: updates := opts.WithOnConflict.Action.([]ColumnValue) set := make(clause.Set, 0, len(updates)) for _, s := range updates { // make sure it's not one of the std immutable columns if contains([]string{"createtime", "publicid"}, strings.ToLower(s.Column)) { return fmt.Errorf("%s: cannot do update on conflict for column %s: %w", op, s.Column, ErrInvalidParameter) } switch sv := s.Value.(type) { case Column: set = append(set, sv.toAssignment(s.Column)) case ExprValue: set = append(set, sv.toAssignment(s.Column)) default: set = append(set, rawAssignment(s.Column, s.Value)) } } c.DoUpdates = set default: return fmt.Errorf("%s: invalid conflict action %v: %w", op, reflect.TypeOf(opts.WithOnConflict.Action), ErrInvalidParameter) } if opts.WithVersion != nil || opts.WithWhereClause != "" { where, args, err := rw.whereClausesFromOpts(ctx, i, opts) if err != nil { return fmt.Errorf("%s: %w", op, err) } whereConditions := db.Statement.BuildCondition(where, args...) c.Where = clause.Where{Exprs: whereConditions} } db = db.Clauses(c) } if opts.WithDebug { db = db.Debug() } if opts.WithTable != "" { db = db.Table(opts.WithTable) } if opts.WithBeforeWrite != nil { if err := opts.WithBeforeWrite(i); err != nil { return fmt.Errorf("%s: error before write: %w", op, err) } } tx := db.Create(i) if tx.Error != nil { return fmt.Errorf("%s: create failed: %w", op, tx.Error) } if opts.WithRowsAffected != nil { *opts.WithRowsAffected = tx.RowsAffected } if tx.RowsAffected > 0 && opts.WithAfterWrite != nil { if err := opts.WithAfterWrite(i, int(tx.RowsAffected)); err != nil { return fmt.Errorf("%s: error after write: %w", op, err) } } if err := rw.lookupAfterWrite(ctx, i, opt...); err != nil { return fmt.Errorf("%s: %w", op, err) } return nil } // CreateItems will create multiple items of the same type. Supported options: // WithBatchSize, WithDebug, WithBeforeWrite, WithAfterWrite, // WithReturnRowsAffected, OnConflict, WithVersion, WithTable, and WithWhere. // WithLookup is not a supported option. func (rw *RW) CreateItems(ctx context.Context, createItems interface{}, opt ...Option) error { const op = "dbw.CreateItems" switch { case rw.underlying == nil: return fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) case isNil(createItems): return fmt.Errorf("%s: missing items: %w", op, ErrInvalidParameter) } valCreateItems := reflect.ValueOf(createItems) switch { case valCreateItems.Kind() != reflect.Slice: return fmt.Errorf("%s: not a slice: %w", op, ErrInvalidParameter) case valCreateItems.Len() == 0: return fmt.Errorf("%s: missing items: %w", op, ErrInvalidParameter) } if err := raiseErrorOnHooks(createItems); err != nil { return fmt.Errorf("%s: %w", op, err) } opts := GetOpts(opt...) switch { case opts.WithLookup: return fmt.Errorf("%s: with lookup not a supported option: %w", op, ErrInvalidParameter) } var foundType reflect.Type for i := 0; i < valCreateItems.Len(); i++ { // verify that createItems are all the same type and do some bits on each item if i == 0 { foundType = reflect.TypeOf(valCreateItems.Index(i).Interface()) } currentType := reflect.TypeOf(valCreateItems.Index(i).Interface()) if currentType == nil { return fmt.Errorf("%s: unable to determine type of item %d: %w", op, i, ErrInvalidParameter) } if foundType != currentType { return fmt.Errorf("%s: create items contains disparate types. item %d is not a %s: %w", op, i, foundType.Name(), ErrInvalidParameter) } // these fields should be nil, since they are not writeable and we want the // db to manage them setFieldsToNil(valCreateItems.Index(i).Interface(), NonCreatableFields()) // vet each item if !opts.WithSkipVetForWrite { if vetter, ok := valCreateItems.Index(i).Interface().(VetForWriter); ok { if err := vetter.VetForWrite(ctx, rw, CreateOp); err != nil { return fmt.Errorf("%s: %w", op, err) } } } } if opts.WithBeforeWrite != nil { if err := opts.WithBeforeWrite(createItems); err != nil { return fmt.Errorf("%s: error before write: %w", op, err) } } db := rw.underlying.wrapped.WithContext(ctx) if opts.WithOnConflict != nil { c := clause.OnConflict{} switch opts.WithOnConflict.Target.(type) { case Constraint: c.OnConstraint = string(opts.WithOnConflict.Target.(Constraint)) case Columns: columns := make([]clause.Column, 0, len(opts.WithOnConflict.Target.(Columns))) for _, name := range opts.WithOnConflict.Target.(Columns) { columns = append(columns, clause.Column{Name: name}) } c.Columns = columns default: return fmt.Errorf("%s: invalid conflict target %v: %w", op, reflect.TypeOf(opts.WithOnConflict.Target), ErrInvalidParameter) } switch opts.WithOnConflict.Action.(type) { case DoNothing: c.DoNothing = true case UpdateAll: c.UpdateAll = true case []ColumnValue: updates := opts.WithOnConflict.Action.([]ColumnValue) set := make(clause.Set, 0, len(updates)) for _, s := range updates { // make sure it's not one of the std immutable columns if contains([]string{"createtime", "publicid"}, strings.ToLower(s.Column)) { return fmt.Errorf("%s: cannot do update on conflict for column %s: %w", op, s.Column, ErrInvalidParameter) } switch sv := s.Value.(type) { case Column: set = append(set, sv.toAssignment(s.Column)) case ExprValue: set = append(set, sv.toAssignment(s.Column)) default: set = append(set, rawAssignment(s.Column, s.Value)) } } c.DoUpdates = set default: return fmt.Errorf("%s: invalid conflict action %v: %w", op, reflect.TypeOf(opts.WithOnConflict.Action), ErrInvalidParameter) } if opts.WithVersion != nil || opts.WithWhereClause != "" { // this is a bit of a hack, but we need to pass in one of the items // to get the where clause since we need to get the gorm Model and // Parse the gorm statement to build the where clause where, args, err := rw.whereClausesFromOpts(ctx, valCreateItems.Index(0).Interface(), opts) if err != nil { return fmt.Errorf("%s: %w", op, err) } whereConditions := db.Statement.BuildCondition(where, args...) c.Where = clause.Where{Exprs: whereConditions} } db = db.Clauses(c) } if opts.WithDebug { db = db.Debug() } if opts.WithTable != "" { db = db.Table(opts.WithTable) } tx := db.CreateInBatches(createItems, opts.WithBatchSize) if tx.Error != nil { return fmt.Errorf("%s: create failed: %w", op, tx.Error) } if opts.WithRowsAffected != nil { *opts.WithRowsAffected = tx.RowsAffected } if tx.RowsAffected > 0 && opts.WithAfterWrite != nil { if err := opts.WithAfterWrite(createItems, int(tx.RowsAffected)); err != nil { return fmt.Errorf("%s: error after write: %w", op, err) } } return nil } func setFieldsToNil(i interface{}, fieldNames []string) { // Note: error cases are not handled _ = Clear(i, fieldNames, 2) } // Clear sets fields in the value pointed to by i to their zero value. // Clear descends i to depth clearing fields at each level. i must be a // pointer to a struct. Cycles in i are not detected. // // A depth of 2 will change i and i's children. A depth of 1 will change i // but no children of i. A depth of 0 will return with no changes to i. func Clear(i interface{}, fields []string, depth int) error { const op = "dbw.Clear" if len(fields) == 0 || depth == 0 { return nil } fm := make(map[string]bool) for _, f := range fields { fm[f] = true } v := reflect.ValueOf(i) switch v.Kind() { case reflect.Ptr: if v.IsNil() || v.Elem().Kind() != reflect.Struct { return fmt.Errorf("%s: %w", op, ErrInvalidParameter) } clear(v, fm, depth) default: return fmt.Errorf("%s: %w", op, ErrInvalidParameter) } return nil } func clear(v reflect.Value, fields map[string]bool, depth int) { if depth == 0 { return } depth-- switch v.Kind() { case reflect.Ptr: clear(v.Elem(), fields, depth+1) case reflect.Struct: typeOfT := v.Type() for i := 0; i < v.NumField(); i++ { f := v.Field(i) if ok := fields[typeOfT.Field(i).Name]; ok { if f.IsValid() && f.CanSet() { f.Set(reflect.Zero(f.Type())) } continue } clear(f, fields, depth) } } }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package dbw import ( "context" "database/sql" "fmt" "strings" "github.com/hashicorp/go-hclog" "github.com/jackc/pgconn" _ "github.com/jackc/pgx/v5" // required to load postgres drivers "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/logger" ) // DbType defines a database type. It's not an exhaustive list of database // types which can be used by the dbw package, since you can always use // OpenWith(...) to connect to KnownDB types. type DbType int const ( // UnknownDB is an unknown db type UnknownDB DbType = 0 // Postgres is a postgre db type Postgres DbType = 1 // Sqlite is a sqlite db type Sqlite DbType = 2 ) // String provides a string rep of the DbType. func (db DbType) String() string { return [...]string{ "unknown", "postgres", "sqlite", }[db] } // StringToDbType provides a string to type conversion. If the type is known, // then UnknownDB with and error is returned. func StringToDbType(dialect string) (DbType, error) { switch dialect { case "postgres": return Postgres, nil case "sqlite": return Sqlite, nil default: return UnknownDB, fmt.Errorf("%s is an unknown dialect", dialect) } } // DB is a wrapper around whatever is providing the interface for database // operations (typically an ORM). DB uses database/sql to maintain connection // pool. type DB struct { wrapped *gorm.DB } // DbType will return the DbType and raw name of the connection type func (db *DB) DbType() (typ DbType, rawName string, e error) { rawName = db.wrapped.Dialector.Name() typ, _ = StringToDbType(rawName) return typ, rawName, nil } // Debug will enable/disable debug info for the connection func (db *DB) Debug(on bool) { if on { // info level in the Gorm domain which maps to a debug level in this domain db.LogLevel(Info) } else { // the default level in the gorm domain is: error level db.LogLevel(Error) } } // LogLevel defines a log level type LogLevel int const ( // Default specifies the default log level Default LogLevel = iota // Silent is the silent log level Silent // Error is the error log level Error // Warn is the warning log level Warn // Info is the info log level Info ) // LogLevel will set the logging level for the db func (db *DB) LogLevel(l LogLevel) { db.wrapped.Logger = db.wrapped.Logger.LogMode(logger.LogLevel(l)) } // SqlDB returns the underlying sql.DB Note: this makes it possible to do // things like set database/sql connection options like SetMaxIdleConns. If // you're simply setting max/min connections then you should use the // WithMinOpenConnections and WithMaxOpenConnections options when // "opening" the database. // // Care should be take when deciding to use this for basic database operations // like Exec, Query, etc since these functions are already provided by dbw.RW // which provides a layer of encapsulation of the underlying database. func (db *DB) SqlDB(_ context.Context) (*sql.DB, error) { const op = "dbw.(DB).SqlDB" if db.wrapped == nil { return nil, fmt.Errorf("%s: missing underlying database: %w", op, ErrInternal) } return db.wrapped.DB() } // Close the database // // Note: Consider if you need to call Close() on the returned DB. Typically the // answer is no, but there are occasions when it's necessary. See the sql.DB // docs for more information. func (db *DB) Close(ctx context.Context) error { const op = "dbw.(DB).Close" if db.wrapped == nil { return fmt.Errorf("%s: missing underlying database: %w", op, ErrInternal) } underlying, err := db.wrapped.DB() if err != nil { return fmt.Errorf("%s: %w", op, err) } return underlying.Close() } // Open a database connection which is long-lived. The options of // WithLogger, WithLogLevel and WithMaxOpenConnections are supported. // // Note: Consider if you need to call Close() on the returned DB. Typically the // answer is no, but there are occasions when it's necessary. See the sql.DB // docs for more information. func Open(dbType DbType, connectionUrl string, opt ...Option) (*DB, error) { const op = "dbw.Open" if connectionUrl == "" { return nil, fmt.Errorf("%s: missing connection url: %w", op, ErrInvalidParameter) } var dialect gorm.Dialector switch dbType { case Postgres: dialect = postgres.New(postgres.Config{ DSN: connectionUrl, }, ) case Sqlite: dialect = sqlite.Open(connectionUrl) default: return nil, fmt.Errorf("unable to open %s database type", dbType) } db, err := openDialector(dialect, opt...) if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } if dbType == Sqlite { if _, err := New(db).Exec(context.Background(), "PRAGMA foreign_keys=ON", nil); err != nil { return nil, fmt.Errorf("%s: unable to enable sqlite foreign keys: %w", op, err) } } return db, nil } // Dialector provides a set of functions the database dialect must satisfy to // be used with OpenWith(...) // It's a simple wrapper of the gorm.Dialector and provides the ability to open // any support gorm dialect driver. type Dialector interface { gorm.Dialector } // OpenWith will open a database connection using a Dialector which is // long-lived. The options of WithLogger, WithLogLevel and // WithMaxOpenConnections are supported. // // Note: Consider if you need to call Close() on the returned DB. Typically the // answer is no, but there are occasions when it's necessary. See the sql.DB // docs for more information. func OpenWith(dialector Dialector, opt ...Option) (*DB, error) { return openDialector(dialector, opt...) } func openDialector(dialect gorm.Dialector, opt ...Option) (*DB, error) { db, err := gorm.Open(dialect, &gorm.Config{}) if err != nil { return nil, fmt.Errorf("unable to open database: %w", err) } if strings.ToLower(dialect.Name()) == "sqlite" { if err := db.Exec("PRAGMA foreign_keys=ON", nil).Error; err != nil { return nil, fmt.Errorf("unable to enable sqlite foreign keys: %w", err) } } opts := GetOpts(opt...) if opts.WithLogger != nil { newLogger := logger.New( getGormLogger(opts.WithLogger), logger.Config{ LogLevel: logger.LogLevel(opts.withLogLevel), // Log level Colorful: false, // Disable color }, ) db = db.Session(&gorm.Session{Logger: newLogger}) } if opts.WithMaxOpenConnections > 0 { if opts.WithMinOpenConnections > 0 && (opts.WithMaxOpenConnections < opts.WithMinOpenConnections) { return nil, fmt.Errorf("unable to create db object with dialect %s: %s", dialect, fmt.Sprintf("max_open_connections must be unlimited by setting 0 or at least %d", opts.WithMinOpenConnections)) } underlyingDB, err := db.DB() if err != nil { return nil, fmt.Errorf("unable retrieve db: %w", err) } underlyingDB.SetMaxOpenConns(opts.WithMaxOpenConnections) } return &DB{wrapped: db}, nil } type gormLogger struct { logger hclog.Logger } func (g gormLogger) Printf(_ string, values ...interface{}) { if len(values) > 1 { switch values[1].(type) { case *pgconn.PgError: g.logger.Trace("error from database adapter", "location", values[0], "error", values[1]) } } } func getGormLogger(log hclog.Logger) gormLogger { return gormLogger{logger: log} }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package dbw import ( "context" "fmt" "reflect" ) // Delete a resource in the db with options: WithWhere, WithDebug, WithTable, // and WithVersion. WithWhere and WithVersion allows specifying a additional // constraints on the operation in addition to the PKs. Delete returns the // number of rows deleted and any errors. func (rw *RW) Delete(ctx context.Context, i interface{}, opt ...Option) (int, error) { const op = "dbw.Delete" if rw.underlying == nil { return noRowsAffected, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) } if isNil(i) { return noRowsAffected, fmt.Errorf("%s: missing interface: %w", op, ErrInvalidParameter) } if err := raiseErrorOnHooks(i); err != nil { return noRowsAffected, fmt.Errorf("%s: %w", op, err) } opts := GetOpts(opt...) mDb := rw.underlying.wrapped.Model(i) err := mDb.Statement.Parse(i) if err == nil && mDb.Statement.Schema == nil { return noRowsAffected, fmt.Errorf("%s: (internal error) unable to parse stmt: %w", op, ErrUnknown) } reflectValue := reflect.Indirect(reflect.ValueOf(i)) for _, pf := range mDb.Statement.Schema.PrimaryFields { if _, isZero := pf.ValueOf(ctx, reflectValue); isZero { return noRowsAffected, fmt.Errorf("%s: primary key %s is not set: %w", op, pf.Name, ErrInvalidParameter) } } if opts.WithBeforeWrite != nil { if err := opts.WithBeforeWrite(i); err != nil { return noRowsAffected, fmt.Errorf("%s: error before write: %w", op, err) } } db := rw.underlying.wrapped.WithContext(ctx) if opts.WithVersion != nil || opts.WithWhereClause != "" { where, args, err := rw.whereClausesFromOpts(ctx, i, opts) if err != nil { return noRowsAffected, fmt.Errorf("%s: %w", op, err) } db = db.Where(where, args...) } if opts.WithDebug { db = db.Debug() } if opts.WithTable != "" { db = db.Table(opts.WithTable) } db = db.Delete(i) if db.Error != nil { return noRowsAffected, fmt.Errorf("%s: %w", op, db.Error) } rowsDeleted := int(db.RowsAffected) if rowsDeleted > 0 && opts.WithAfterWrite != nil { if err := opts.WithAfterWrite(i, rowsDeleted); err != nil { return rowsDeleted, fmt.Errorf("%s: error after write: %w", op, err) } } return rowsDeleted, nil } // DeleteItems will delete multiple items of the same type. Options supported: // WithWhereClause, WithDebug, WithTable func (rw *RW) DeleteItems(ctx context.Context, deleteItems interface{}, opt ...Option) (int, error) { const op = "dbw.DeleteItems" switch { case rw.underlying == nil: return noRowsAffected, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) case isNil(deleteItems): return noRowsAffected, fmt.Errorf("%s: no interfaces to delete: %w", op, ErrInvalidParameter) } valDeleteItems := reflect.ValueOf(deleteItems) switch { case valDeleteItems.Kind() != reflect.Slice: return noRowsAffected, fmt.Errorf("%s: not a slice: %w", op, ErrInvalidParameter) case valDeleteItems.Len() == 0: return noRowsAffected, fmt.Errorf("%s: missing items: %w", op, ErrInvalidParameter) } if err := raiseErrorOnHooks(deleteItems); err != nil { return noRowsAffected, fmt.Errorf("%s: %w", op, err) } opts := GetOpts(opt...) switch { case opts.WithLookup: return noRowsAffected, fmt.Errorf("%s: with lookup not a supported option: %w", op, ErrInvalidParameter) case opts.WithVersion != nil: return noRowsAffected, fmt.Errorf("%s: with version is not a supported option: %w", op, ErrInvalidParameter) } // we need to dig out the stmt so in just a sec we can make sure the PKs are // set for all the items, so we'll just use the first item to do so. mDb := rw.underlying.wrapped.Model(valDeleteItems.Index(0).Interface()) err := mDb.Statement.Parse(valDeleteItems.Index(0).Interface()) switch { case err != nil: return noRowsAffected, fmt.Errorf("%s: (internal error) error parsing stmt: %w", op, err) case err == nil && mDb.Statement.Schema == nil: return noRowsAffected, fmt.Errorf("%s: (internal error) unable to parse stmt: %w", op, ErrUnknown) } // verify that deleteItems are all the same type, among a myriad of // other things on the set of items var foundType reflect.Type for i := 0; i < valDeleteItems.Len(); i++ { if i == 0 { foundType = reflect.TypeOf(valDeleteItems.Index(i).Interface()) } currentType := reflect.TypeOf(valDeleteItems.Index(i).Interface()) switch { case isNil(valDeleteItems.Index(i).Interface()) || currentType == nil: return noRowsAffected, fmt.Errorf("%s: unable to determine type of item %d: %w", op, i, ErrInvalidParameter) case foundType != currentType: return noRowsAffected, fmt.Errorf("%s: items contain disparate types. item %d is not a %s: %w", op, i, foundType.Name(), ErrInvalidParameter) } if opts.WithWhereClause == "" { // make sure the PK is set for the current item reflectValue := reflect.Indirect(reflect.ValueOf(valDeleteItems.Index(i).Interface())) for _, pf := range mDb.Statement.Schema.PrimaryFields { if _, isZero := pf.ValueOf(ctx, reflectValue); isZero { return noRowsAffected, fmt.Errorf("%s: primary key %s is not set: %w", op, pf.Name, ErrInvalidParameter) } } } } if opts.WithBeforeWrite != nil { if err := opts.WithBeforeWrite(deleteItems); err != nil { return noRowsAffected, fmt.Errorf("%s: error before write: %w", op, err) } } db := rw.underlying.wrapped.WithContext(ctx) if opts.WithDebug { db = db.Debug() } if opts.WithWhereClause != "" { where, args, err := rw.whereClausesFromOpts(ctx, valDeleteItems.Index(0).Interface(), opts) if err != nil { return noRowsAffected, fmt.Errorf("%s: %w", op, err) } db = db.Where(where, args...) } switch { case opts.WithTable != "": db = db.Table(opts.WithTable) default: tabler, ok := valDeleteItems.Index(0).Interface().(tableNamer) if ok { db = db.Table(tabler.TableName()) } } db = db.Delete(deleteItems) if db.Error != nil { return noRowsAffected, fmt.Errorf("%s: %w", op, db.Error) } rowsDeleted := int(db.RowsAffected) if rowsDeleted > 0 && opts.WithAfterWrite != nil { if err := opts.WithAfterWrite(deleteItems, int(rowsDeleted)); err != nil { return rowsDeleted, fmt.Errorf("%s: error after write: %w", op, err) } } return rowsDeleted, nil } type tableNamer interface { TableName() string }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package dbw import ( "context" "fmt" "time" ) // DoTx will wrap the Handler func passed within a transaction with retries // you should ensure that any objects written to the db in your TxHandler are retryable, which // means that the object may be sent to the db several times (retried), so // things like the primary key may need to be reset before retry. func (rw *RW) DoTx(ctx context.Context, retryErrorsMatchingFn func(error) bool, retries uint, backOff Backoff, handler TxHandler) (RetryInfo, error) { const op = "dbw.DoTx" if rw.underlying == nil { return RetryInfo{}, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) } if backOff == nil { return RetryInfo{}, fmt.Errorf("%s: missing backoff: %w", op, ErrInvalidParameter) } if handler == nil { return RetryInfo{}, fmt.Errorf("%s: missing handler: %w", op, ErrInvalidParameter) } if retryErrorsMatchingFn == nil { return RetryInfo{}, fmt.Errorf("%s: missing retry errors matching function: %w", op, ErrInvalidParameter) } info := RetryInfo{} for attempts := uint(1); ; attempts++ { if attempts > retries+1 { return info, fmt.Errorf("%s: too many retries: %d of %d: %w", op, attempts-1, retries+1, ErrMaxRetries) } // step one of this, start a transaction... newTx := rw.underlying.wrapped.WithContext(ctx) newTx = newTx.Begin() newRW := &RW{underlying: &DB{newTx}} if err := handler(newRW, newRW); err != nil { if err := newTx.Rollback().Error; err != nil { return info, fmt.Errorf("%s: %w", op, err) } if retry := retryErrorsMatchingFn(err); retry { d := backOff.Duration(attempts) info.Retries++ info.Backoff = info.Backoff + d select { case <-ctx.Done(): return info, fmt.Errorf("%s: cancelled: %w", op, err) case <-time.After(d): continue } } return info, fmt.Errorf("%s: %w", op, err) } if err := newTx.Commit().Error; err != nil { if err := newTx.Rollback().Error; err != nil { return info, fmt.Errorf("%s: %w", op, err) } return info, fmt.Errorf("%s: %w", op, err) } return info, nil // it all worked!!! } }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package dbw import ( "bytes" "fmt" "strings" "github.com/hashicorp/go-secure-stdlib/base62" "golang.org/x/crypto/blake2b" ) // NewId creates a new random base62 ID with the provided prefix with an // underscore delimiter func NewId(prefix string, opt ...Option) (string, error) { const op = "dbw.NewId" if prefix == "" { return "", fmt.Errorf("%s: missing prefix: %w", op, ErrInvalidParameter) } var publicId string var err error opts := GetOpts(opt...) if len(opts.WithPrngValues) > 0 { sum := blake2b.Sum256([]byte(strings.Join(opts.WithPrngValues, "|"))) reader := bytes.NewReader(sum[0:]) publicId, err = base62.RandomWithReader(10, reader) } else { publicId, err = base62.Random(10) } if err != nil { return "", fmt.Errorf("%s: unable to generate id: %w", op, ErrInternal) } return fmt.Sprintf("%s_%s", prefix, publicId), nil }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package dbw import ( "context" "fmt" "gorm.io/gorm" ) // LookupBy will lookup a resource by it's primary keys, which must be // unique. If the resource implements either ResourcePublicIder or // ResourcePrivateIder interface, then they are used as the resource's // primary key for lookup. Otherwise, the resource tags are used to // determine it's primary key(s) for lookup. The WithDebug and WithTable // options are supported. func (rw *RW) LookupBy(ctx context.Context, resourceWithIder interface{}, opt ...Option) error { const op = "dbw.LookupById" if rw.underlying == nil { return fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) } if err := raiseErrorOnHooks(resourceWithIder); err != nil { return fmt.Errorf("%s: %w", op, err) } if err := validateResourcesInterface(resourceWithIder); err != nil { return fmt.Errorf("%s: %w", op, err) } where, keys, err := rw.primaryKeysWhere(ctx, resourceWithIder) if err != nil { return fmt.Errorf("%s: %w", op, err) } opts := GetOpts(opt...) db := rw.underlying.wrapped.WithContext(ctx) if opts.WithTable != "" { db = db.Table(opts.WithTable) } if opts.WithDebug { db = db.Debug() } rw.clearDefaultNullResourceFields(ctx, resourceWithIder) if err := db.Where(where, keys...).First(resourceWithIder).Error; err != nil { if err == gorm.ErrRecordNotFound { return fmt.Errorf("%s: %w", op, ErrRecordNotFound) } return fmt.Errorf("%s: %w", op, err) } return nil } // LookupByPublicId will lookup resource by its public_id, which must be unique. // The WithTable option is supported. func (rw *RW) LookupByPublicId(ctx context.Context, resource ResourcePublicIder, opt ...Option) error { return rw.LookupBy(ctx, resource, opt...) } func (rw *RW) lookupAfterWrite(ctx context.Context, i interface{}, opt ...Option) error { const op = "dbw.lookupAfterWrite" opts := GetOpts(opt...) withLookup := opts.WithLookup if err := raiseErrorOnHooks(i); err != nil { return fmt.Errorf("%s: %w", op, err) } if !withLookup { return nil } if err := rw.LookupBy(ctx, i, opt...); err != nil { return fmt.Errorf("%s: %w", op, err) } return nil }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package dbw import ( "github.com/hashicorp/go-hclog" ) // GetOpts - iterate the inbound Options and return a struct. func GetOpts(opt ...Option) Options { opts := getDefaultOptions() for _, o := range opt { if o != nil { o(&opts) } } return opts } // Option - how Options are passed as arguments. type Option func(*Options) // Options - how Options are represented which have been set via an Option // function. Use GetOpts(...) to populated this struct with the options that // have been specified for an operation. All option fields are exported so // they're available for use by other packages. type Options struct { // WithBeforeWrite provides and option to provide a func to be called before a // write operation. The i interface{} passed at runtime will be the resource(s) // being written. WithBeforeWrite func(i interface{}) error // WithAfterWrite provides and option to provide a func to be called after a // write operation. The i interface{} passed at runtime will be the resource(s) // being written. WithAfterWrite func(i interface{}, rowsAffected int) error // WithLookup enables a lookup after a write operation. WithLookup bool // WithLimit provides an option to provide a limit. Intentionally allowing // negative integers. If WithLimit < 0, then unlimited results are returned. // If WithLimit == 0, then default limits are used for results (see DefaultLimit // const). WithLimit int // WithFieldMaskPaths provides an option to provide field mask paths for update // operations. WithFieldMaskPaths []string // WithNullPaths provides an option to provide null paths for update // operations. WithNullPaths []string // WithVersion provides an option version number for update operations. Using // this option requires that your resource has a version column that's // incremented for every successful update operation. Version provides an // optimistic locking mechanism for write operations. WithVersion *uint32 WithSkipVetForWrite bool // WithWhereClause provides an option to provide a where clause for an // operation. WithWhereClause string // WithWhereClauseArgs provides an option to provide a where clause arguments for an // operation. WithWhereClauseArgs []interface{} // WithOrder provides an option to provide an order when searching and looking // up. WithOrder string // WithPrngValues provides an option to provide values to seed an PRNG when generating IDs WithPrngValues []string // WithLogger specifies an optional hclog to use for db operations. It's only // valid for Open(..) and OpenWith(...) WithLogger hclog.Logger // WithMinOpenConnections specifies and optional min open connections for the // database. A value of zero means that there is no min. WithMaxOpenConnections int // WithMaxOpenConnections specifies and optional max open connections for the // database. A value of zero equals unlimited connections WithMinOpenConnections int // WithDebug indicates that the given operation should invoke debug output // mode WithDebug bool // WithOnConflict specifies an optional on conflict criteria which specify // alternative actions to take when an insert results in a unique constraint or // exclusion constraint error WithOnConflict *OnConflict // WithRowsAffected specifies an option for returning the rows affected // and typically used with "bulk" write operations. WithRowsAffected *int64 // WithTable specifies an option for setting a table name to use for the // operation. WithTable string // WithBatchSize specifies an option for setting the batch size for bulk // operations. If WithBatchSize == 0, then the default batch size is used. WithBatchSize int withLogLevel LogLevel } func getDefaultOptions() Options { return Options{ WithFieldMaskPaths: []string{}, WithNullPaths: []string{}, WithBatchSize: DefaultBatchSize, withLogLevel: Error, } } // WithBeforeWrite provides and option to provide a func to be called before a // write operation. The i interface{} passed at runtime will be the resource(s) // being written. func WithBeforeWrite(fn func(i interface{}) error) Option { return func(o *Options) { o.WithBeforeWrite = fn } } // WithAfterWrite provides and option to provide a func to be called after a // write operation. The i interface{} passed at runtime will be the resource(s) // being written. func WithAfterWrite(fn func(i interface{}, rowsAffected int) error) Option { return func(o *Options) { o.WithAfterWrite = fn } } // WithLookup enables a lookup after a write operation. func WithLookup(enable bool) Option { return func(o *Options) { o.WithLookup = enable } } // WithFieldMaskPaths provides an option to provide field mask paths for update // operations. func WithFieldMaskPaths(paths []string) Option { return func(o *Options) { o.WithFieldMaskPaths = paths } } // WithNullPaths provides an option to provide null paths for update operations. func WithNullPaths(paths []string) Option { return func(o *Options) { o.WithNullPaths = paths } } // WithLimit provides an option to provide a limit. Intentionally allowing // negative integers. If WithLimit < 0, then unlimited results are returned. // If WithLimit == 0, then default limits are used for results (see DefaultLimit // const). func WithLimit(limit int) Option { return func(o *Options) { o.WithLimit = limit } } // WithVersion provides an option version number for update operations. Using // this option requires that your resource has a version column that's // incremented for every successful update operation. Version provides an // optimistic locking mechanism for write operations. func WithVersion(version *uint32) Option { return func(o *Options) { o.WithVersion = version } } // WithSkipVetForWrite provides an option to allow skipping vet checks to allow // testing lower-level SQL triggers and constraints func WithSkipVetForWrite(enable bool) Option { return func(o *Options) { o.WithSkipVetForWrite = enable } } // WithWhere provides an option to provide a where clause with arguments for an // operation. func WithWhere(whereClause string, args ...interface{}) Option { return func(o *Options) { o.WithWhereClause = whereClause o.WithWhereClauseArgs = append(o.WithWhereClauseArgs, args...) } } // WithOrder provides an option to provide an order when searching and looking // up. func WithOrder(withOrder string) Option { return func(o *Options) { o.WithOrder = withOrder } } // WithPrngValues provides an option to provide values to seed an PRNG when generating IDs func WithPrngValues(withPrngValues []string) Option { return func(o *Options) { o.WithPrngValues = withPrngValues } } // WithLogger specifies an optional hclog to use for db operations. It's only // valid for Open(..) and OpenWith(...) func WithLogger(l hclog.Logger) Option { return func(o *Options) { o.WithLogger = l } } // WithMaxOpenConnections specifies and optional max open connections for the // database. A value of zero equals unlimited connections func WithMaxOpenConnections(max int) Option { return func(o *Options) { o.WithMaxOpenConnections = max } } // WithMinOpenConnections specifies and optional min open connections for the // database. A value of zero means that there is no min. func WithMinOpenConnections(max int) Option { return func(o *Options) { o.WithMinOpenConnections = max } } // WithDebug specifies the given operation should invoke debug mode for the // database output func WithDebug(with bool) Option { return func(o *Options) { o.WithDebug = with } } // WithOnConflict specifies an optional on conflict criteria which specify // alternative actions to take when an insert results in a unique constraint or // exclusion constraint error func WithOnConflict(onConflict *OnConflict) Option { return func(o *Options) { o.WithOnConflict = onConflict } } // WithReturnRowsAffected specifies an option for returning the rows affected // and typically used with "bulk" write operations. func WithReturnRowsAffected(rowsAffected *int64) Option { return func(o *Options) { o.WithRowsAffected = rowsAffected } } // WithTable specifies an option for setting a table name to use for the // operation. func WithTable(name string) Option { return func(o *Options) { o.WithTable = name } } // WithLogLevel specifies an option for setting the log level func WithLogLevel(l LogLevel) Option { return func(o *Options) { o.withLogLevel = l } } // WithBatchSize specifies an option for setting the batch size for bulk // operations like CreateItems. If WithBatchSize == 0, the default batch size is // used (see DefaultBatchSize const). func WithBatchSize(size int) Option { return func(o *Options) { o.WithBatchSize = size } }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package dbw import ( "context" "database/sql" "fmt" ) // Query will run the raw query and return the *sql.Rows results. Query will // operate within the context of any ongoing transaction for the Reader. The // caller must close the returned *sql.Rows. Query can/should be used in // combination with ScanRows. The WithDebug option is supported. func (rw *RW) Query(ctx context.Context, sql string, values []interface{}, opt ...Option) (*sql.Rows, error) { const op = "dbw.Query" if rw.underlying == nil { return nil, fmt.Errorf("%s: missing underlying db: %w", op, ErrInternal) } if sql == "" { return nil, fmt.Errorf("%s: missing sql: %w", op, ErrInvalidParameter) } opts := GetOpts(opt...) db := rw.underlying.wrapped.WithContext(ctx) if opts.WithDebug { db = db.Debug() } db = db.Raw(sql, values...) if db.Error != nil { return nil, fmt.Errorf("%s: %w", op, db.Error) } return db.Rows() } // ScanRows will scan the rows into the interface func (rw *RW) ScanRows(rows *sql.Rows, result interface{}) error { const op = "dbw.ScanRows" if rw.underlying == nil { return fmt.Errorf("%s: missing underlying db: %w", op, ErrInternal) } if rows == nil { return fmt.Errorf("%s: missing rows: %w", op, ErrInvalidParameter) } if isNil(result) { return fmt.Errorf("%s: missing result: %w", op, ErrInvalidParameter) } return rw.underlying.wrapped.ScanRows(rows, result) }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package dbw import ( "context" "fmt" "reflect" "strings" "gorm.io/gorm" "gorm.io/gorm/callbacks" ) const ( noRowsAffected = 0 // DefaultLimit is the default for search results when no limit is specified // via the WithLimit(...) option DefaultLimit = 10000 ) // RW uses a DB as a connection for it's read/write operations. This is // basically the primary type for the package's operations. type RW struct { underlying *DB } // ensure that RW implements the interfaces of: Reader and Writer var ( _ Reader = (*RW)(nil) _ Writer = (*RW)(nil) ) // New creates a new RW using an open DB. Note: there can by many RWs that share // the same DB, since the DB manages the connection pool. func New(underlying *DB) *RW { return &RW{underlying: underlying} } // DB returns the underlying DB func (rw *RW) DB() *DB { return rw.underlying } // Exec will execute the sql with the values as parameters. The int returned // is the number of rows affected by the sql. The WithDebug option is supported. func (rw *RW) Exec(ctx context.Context, sql string, values []interface{}, opt ...Option) (int, error) { const op = "dbw.Exec" if rw.underlying == nil { return 0, fmt.Errorf("%s: missing underlying db: %w", op, ErrInternal) } if sql == "" { return noRowsAffected, fmt.Errorf("%s: missing sql: %w", op, ErrInvalidParameter) } opts := GetOpts(opt...) db := rw.underlying.wrapped.WithContext(ctx) if opts.WithDebug { db = db.Debug() } db = db.Exec(sql, values...) if db.Error != nil { return noRowsAffected, fmt.Errorf("%s: %w", op, db.Error) } return int(db.RowsAffected), nil } func (rw *RW) primaryFieldsAreZero(ctx context.Context, i interface{}) ([]string, bool, error) { const op = "dbw.primaryFieldsAreZero" var fieldNames []string tx := rw.underlying.wrapped.Model(i) if err := tx.Statement.Parse(i); err != nil { return nil, false, fmt.Errorf("%s: %w", op, ErrInvalidParameter) } for _, f := range tx.Statement.Schema.PrimaryFields { if f.PrimaryKey { if _, isZero := f.ValueOf(ctx, reflect.ValueOf(i)); isZero { fieldNames = append(fieldNames, f.Name) } } } return fieldNames, len(fieldNames) > 0, nil } func isNil(i interface{}) bool { if i == nil { return true } switch reflect.TypeOf(i).Kind() { case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Slice: return reflect.ValueOf(i).IsNil() } return false } func contains(ss []string, t string) bool { for _, s := range ss { if strings.EqualFold(s, t) { return true } } return false } func validateResourcesInterface(resources interface{}) error { const op = "dbw.validateResourcesInterface" vo := reflect.ValueOf(resources) if vo.Kind() != reflect.Ptr { return fmt.Errorf("%s: interface parameter must to be a pointer: %w", op, ErrInvalidParameter) } e := vo.Elem() if e.Kind() == reflect.Slice { if e.Type().Elem().Kind() != reflect.Ptr { return fmt.Errorf("%s: interface parameter is a slice, but the elements of the slice are not pointers: %w", op, ErrInvalidParameter) } } return nil } func raiseErrorOnHooks(i interface{}) error { const op = "dbw.raiseErrorOnHooks" v := i valOf := reflect.ValueOf(i) if valOf.Kind() == reflect.Slice { if valOf.Len() == 0 { return nil } v = valOf.Index(0).Interface() } switch v.(type) { case // create hooks callbacks.BeforeCreateInterface, callbacks.AfterCreateInterface, callbacks.BeforeSaveInterface, callbacks.AfterSaveInterface, // update hooks callbacks.BeforeUpdateInterface, callbacks.AfterUpdateInterface, // delete hooks callbacks.BeforeDeleteInterface, callbacks.AfterDeleteInterface, // find hooks callbacks.AfterFindInterface: return fmt.Errorf("%s: gorm callback/hooks are not supported: %w", op, ErrInvalidParameter) } return nil } // IsTx returns true if there's an existing transaction in progress func (rw *RW) IsTx() bool { if rw.underlying == nil { return false } switch rw.underlying.wrapped.Statement.ConnPool.(type) { case gorm.TxBeginner, gorm.ConnPoolBeginner: return false default: return true } } func (rw *RW) whereClausesFromOpts(_ context.Context, i interface{}, opts Options) (string, []interface{}, error) { const op = "dbw.whereClausesFromOpts" var where []string var args []interface{} if opts.WithVersion != nil { if *opts.WithVersion == 0 { return "", nil, fmt.Errorf("%s: with version option is zero: %w", op, ErrInvalidParameter) } mDb := rw.underlying.wrapped.Model(i) err := mDb.Statement.Parse(i) if err != nil && mDb.Statement.Schema == nil { return "", nil, fmt.Errorf("%s: (internal error) unable to parse stmt: %w", op, ErrUnknown) } if !contains(mDb.Statement.Schema.DBNames, "version") { return "", nil, fmt.Errorf("%s: %s does not have a version field: %w", op, mDb.Statement.Schema.Table, ErrInvalidParameter) } if opts.WithOnConflict != nil { // on conflict clauses requires the version to be qualified with a // table name var tableName string switch { case opts.WithTable != "": tableName = opts.WithTable default: tableName = mDb.Statement.Schema.Table } where = append(where, fmt.Sprintf("%s.version = ?", tableName)) // we need to include the table name because of "on conflict" use cases } else { where = append(where, "version = ?") } args = append(args, opts.WithVersion) } if opts.WithWhereClause != "" { where, args = append(where, opts.WithWhereClause), append(args, opts.WithWhereClauseArgs...) } return strings.Join(where, " and "), args, nil } // clearDefaultNullResourceFields will clear fields in the resource which are // defaulted to a null value. This addresses the unfixed issue in gorm: // https://github.com/go-gorm/gorm/issues/6351 func (rw *RW) clearDefaultNullResourceFields(ctx context.Context, i interface{}) error { const op = "dbw.ClearResourceFields" stmt := rw.underlying.wrapped.Model(i).Statement if err := stmt.Parse(i); err != nil { return fmt.Errorf("%s: %w", op, err) } v := reflect.ValueOf(i) for _, f := range stmt.Schema.Fields { switch { case f.PrimaryKey: // seems a bit redundant, with the test for null, but it's very // important to not clear the primary fields, so we'll make an // explicit test continue case !f.Updatable: // well, based on the gorm tags it's a read-only field, so we're done. continue case !strings.EqualFold(f.DefaultValue, "null"): continue default: _, isZero := f.ValueOf(ctx, v) if isZero { continue } if err := f.Set(stmt.Context, v, f.DefaultValueInterface); err != nil { return fmt.Errorf("%s: unable to set value of non-zero field: %w", op, err) } } } return nil } func (rw *RW) primaryKeysWhere(ctx context.Context, i interface{}) (string, []interface{}, error) { const op = "dbw.primaryKeysWhere" var fieldNames []string var fieldValues []interface{} tx := rw.underlying.wrapped.Model(i) if err := tx.Statement.Parse(i); err != nil { return "", nil, fmt.Errorf("%s: %w", op, err) } switch resourceType := i.(type) { case ResourcePublicIder: if resourceType.GetPublicId() == "" { return "", nil, fmt.Errorf("%s: missing primary key: %w", op, ErrInvalidParameter) } fieldValues = []interface{}{resourceType.GetPublicId()} fieldNames = []string{"public_id"} case ResourcePrivateIder: if resourceType.GetPrivateId() == "" { return "", nil, fmt.Errorf("%s: missing primary key: %w", op, ErrInvalidParameter) } fieldValues = []interface{}{resourceType.GetPrivateId()} fieldNames = []string{"private_id"} default: v := reflect.ValueOf(i) for _, f := range tx.Statement.Schema.PrimaryFields { if f.PrimaryKey { val, isZero := f.ValueOf(ctx, v) if isZero { return "", nil, fmt.Errorf("%s: primary field %s is zero: %w", op, f.Name, ErrInvalidParameter) } fieldNames = append(fieldNames, f.DBName) fieldValues = append(fieldValues, val) } } } if len(fieldNames) == 0 { return "", nil, fmt.Errorf("%s: no primary key(s) for %t: %w", op, i, ErrInvalidParameter) } clauses := make([]string, 0, len(fieldNames)) for _, col := range fieldNames { clauses = append(clauses, fmt.Sprintf("%s = ?", col)) } return strings.Join(clauses, " and "), fieldValues, nil } // LookupWhere will lookup the first resource using a where clause with // parameters (it only returns the first one). Supports WithDebug, and // WithTable options. func (rw *RW) LookupWhere(ctx context.Context, resource interface{}, where string, args []interface{}, opt ...Option) error { const op = "dbw.LookupWhere" if rw.underlying == nil { return fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) } if err := validateResourcesInterface(resource); err != nil { return fmt.Errorf("%s: %w", op, err) } if err := raiseErrorOnHooks(resource); err != nil { return fmt.Errorf("%s: %w", op, err) } opts := GetOpts(opt...) db := rw.underlying.wrapped.WithContext(ctx) if opts.WithTable != "" { db = db.Table(opts.WithTable) } if opts.WithDebug { db = db.Debug() } if err := db.Where(where, args...).First(resource).Error; err != nil { if err == gorm.ErrRecordNotFound { return fmt.Errorf("%s: %w", op, ErrRecordNotFound) } return fmt.Errorf("%s: %w", op, err) } return nil } // SearchWhere will search for all the resources it can find using a where // clause with parameters. An error will be returned if args are provided without a // where clause. // // Supports WithTable and WithLimit options. If WithLimit < 0, then unlimited results are returned. // If WithLimit == 0, then default limits are used for results. // Supports the WithOrder, WithTable, and WithDebug options. func (rw *RW) SearchWhere(ctx context.Context, resources interface{}, where string, args []interface{}, opt ...Option) error { const op = "dbw.SearchWhere" opts := GetOpts(opt...) if rw.underlying == nil { return fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) } if where == "" && len(args) > 0 { return fmt.Errorf("%s: args provided with empty where: %w", op, ErrInvalidParameter) } if err := raiseErrorOnHooks(resources); err != nil { return fmt.Errorf("%s: %w", op, err) } if err := validateResourcesInterface(resources); err != nil { return fmt.Errorf("%s: %w", op, err) } var err error db := rw.underlying.wrapped.WithContext(ctx) if opts.WithOrder != "" { db = db.Order(opts.WithOrder) } if opts.WithDebug { db = db.Debug() } if opts.WithTable != "" { db = db.Table(opts.WithTable) } // Perform limiting switch { case opts.WithLimit < 0: // any negative number signals unlimited results case opts.WithLimit == 0: // zero signals the default value and default limits db = db.Limit(DefaultLimit) default: db = db.Limit(opts.WithLimit) } if where != "" { db = db.Where(where, args...) } // Perform the query err = db.Find(resources).Error if err != nil { // searching with a slice parameter does not return a gorm.ErrRecordNotFound return fmt.Errorf("%s: %w", op, err) } return nil } func (rw *RW) Dialect() (_ DbType, rawName string, _ error) { return rw.underlying.DbType() }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package dbw import ( "context" "database/sql" "fmt" "os" "strings" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/xo/dburl" pgDriver "gorm.io/driver/postgres" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm/logger" ) // TestSetup is typically called before starting a test and will setup the // database for the test (initialize the database one-time). Do not close the // returned db. Supported test options: WithDebug, WithTestDialect, // WithTestDatabaseUrl, WithTestMigration and WithTestMigrationUsingDB. func TestSetup(t *testing.T, opt ...TestOption) (*DB, string) { require := require.New(t) var url string var err error InitNonUpdatableFields([]string{"CreateTime", "UpdateTime", "PublicId"}) InitNonCreatableFields([]string{"CreateTime", "UpdateTime"}) ctx := context.Background() opts := getTestOpts(opt...) switch strings.ToLower(os.Getenv("DB_DIALECT")) { case "postgres": opts.withDialect = Postgres.String() case "sqlite": opts.withDialect = Sqlite.String() default: if opts.withDialect == "" { opts.withDialect = Sqlite.String() } } if url := os.Getenv("DB_DSN"); url != "" { opts.withTestDatabaseUrl = url } switch { case opts.withDialect == Postgres.String() && opts.withTestDatabaseUrl == "": t.Fatal("missing postgres test db url") case opts.withDialect == Sqlite.String() && opts.withTestDatabaseUrl == "": url = "file::memory:" // just using a temp in-memory sqlite database default: url = opts.withTestDatabaseUrl } switch opts.withDialect { case Postgres.String(): u, err := dburl.Parse(opts.withTestDatabaseUrl) require.NoError(err) db, err := Open(Postgres, u.DSN) require.NoError(err) rw := New(db) tmpDbName, err := NewId("go_db_tmp") tmpDbName = strings.ToLower(tmpDbName) require.NoError(err) _, err = rw.Exec(ctx, fmt.Sprintf(`create database "%s"`, tmpDbName), nil) require.NoError(err) t.Cleanup(func() { _, err = rw.Exec(ctx, `select pg_terminate_backend(pid) from pg_stat_activity where datname = ? and pid <> pg_backend_pid()`, []interface{}{tmpDbName}) assert.NoError(t, err) _, err = rw.Exec(ctx, fmt.Sprintf(`drop database %s`, tmpDbName), nil) assert.NoError(t, err) }) _, err = rw.Exec(ctx, fmt.Sprintf(`grant all privileges on database %s to %s`, tmpDbName, u.User.Username()), nil) require.NoError(err) namesSegs := strings.Split(strings.TrimPrefix(u.Path, "/"), "?") require.Truef(len(namesSegs) > 0, "couldn't determine db name from URL") namesSegs[0] = tmpDbName u.Path = strings.Join(namesSegs, "?") url, _, err = dburl.GenPostgres(u) require.NoError(err) } dbType, err := StringToDbType(opts.withDialect) require.NoError(err) db, err := Open(dbType, url) require.NoError(err) db.wrapped.Logger.LogMode(logger.Error) t.Cleanup(func() { assert.NoError(t, db.Close(ctx), "Got error closing db.") }) if opts.withTestDebug || strings.ToLower(os.Getenv("DEBUG")) == "true" { db.Debug(true) } // we're only going to run one set of migrations. Either one of the // migration functions passed in as an option or the default // TestCreateTables(...) switch { case opts.withTestMigration != nil: err = opts.withTestMigration(ctx, opts.withDialect, url) require.NoError(err) case opts.withTestMigrationUsingDb != nil: var rawDB *sql.DB switch opts.withDialect { case Sqlite.String(): // we need to special case handle sqlite because we may want to run // the migration on an in-memory database that isn't shared // "file::memory:" or ":memory:" so we have to be sure that we don't // open a new connection, which would create a new in-memory db vs // using the existing one already opened in this function a few // lines above with: db, err := Open(dbType, url) // see: https://www.sqlite.org/inmemorydb.html // // luckily, the gorm sqlite ConnPool is the existing opened *sql.DB, // so we can just run the migration on that conn var ok bool rawDB, ok = db.wrapped.ConnPool.(*sql.DB) require.True(ok, "expected the gorm ConnPool to be an *sql.DB") default: var err error rawDB, err = db.wrapped.DB() if err != nil { require.NoError(err) } } err = opts.withTestMigrationUsingDb(ctx, rawDB) require.NoError(err) default: TestCreateTables(t, db) } return db, url } // TestSetupWithMock will return a test DB and an associated Sqlmock which can // be used to mock out the db responses. func TestSetupWithMock(t *testing.T) (*DB, sqlmock.Sqlmock) { t.Helper() require := require.New(t) db, mock, err := sqlmock.New() require.NoError(err) require.NoError(err) dbw, err := OpenWith(pgDriver.New(pgDriver.Config{ Conn: db, })) require.NoError(err) return dbw, mock } // getTestOpts - iterate the inbound TestOptions and return a struct func getTestOpts(opt ...TestOption) testOptions { opts := getDefaultTestOptions() for _, o := range opt { o(&opts) } return opts } // TestOption - how Options are passed as arguments type TestOption func(*testOptions) // options = how options are represented type testOptions struct { withDialect string withTestDatabaseUrl string withTestMigration func(ctx context.Context, dialect, url string) error withTestMigrationUsingDb func(ctx context.Context, db *sql.DB) error withTestDebug bool } func getDefaultTestOptions() testOptions { return testOptions{} } // WithTestDialect provides a way to specify the test database dialect func WithTestDialect(dialect string) TestOption { return func(o *testOptions) { o.withDialect = dialect } } // WithTestMigration provides a way to specify an option func which runs a // required database migration to initialize the database func WithTestMigration(migrationFn func(ctx context.Context, dialect, url string) error) TestOption { return func(o *testOptions) { o.withTestMigration = migrationFn } } // WithTestMigrationUsingDB provides a way to specify an option func which runs a // required database migration to initialize the database using an existing open // sql.DB func WithTestMigrationUsingDB(migrationFn func(ctx context.Context, db *sql.DB) error) TestOption { return func(o *testOptions) { o.withTestMigrationUsingDb = migrationFn } } // WithTestDatabaseUrl provides a way to specify an existing database for tests func WithTestDatabaseUrl(url string) TestOption { return func(o *testOptions) { o.withTestDatabaseUrl = url } } // TestCreateTables will create the test tables for the dbw pkg func TestCreateTables(t *testing.T, conn *DB) { t.Helper() require := require.New(t) testCtx := context.Background() rw := New(conn) var query string switch conn.wrapped.Dialector.Name() { case "sqlite": query = testQueryCreateTablesSqlite case "postgres": query = testQueryCreateTablesPostgres default: t.Fatalf("unknown dialect: %s", conn.wrapped.Dialector.Name()) } _, err := rw.Exec(testCtx, query, nil) require.NoError(err) } func testDropTables(t *testing.T, conn *DB) { t.Helper() require := require.New(t) testCtx := context.Background() rw := New(conn) var query string switch conn.wrapped.Dialector.Name() { case "sqlite": query = testQueryDropTablesSqlite case "postgres": query = testQueryDropTablesPostgres default: t.Fatalf("unknown dialect: %s", conn.wrapped.Dialector.Name()) } _, err := rw.Exec(testCtx, query, nil) require.NoError(err) } const ( testQueryCreateTablesSqlite = ` begin; create table if not exists db_test_user ( public_id text not null constraint db_test_user_pkey primary key, create_time timestamp not null default current_timestamp, update_time timestamp not null default current_timestamp, name text unique, phone_number text, email text, version int default 1 ); create trigger update_time_column_db_test_user before update on db_test_user for each row when new.public_id <> old.public_id or new.name <> old.name or new.phone_number <> old.phone_number or new.email <> old.email or new.version <> old.version begin update db_test_user set update_time = datetime('now','localtime') where rowid == new.rowid; end; create trigger immutable_columns_db_test_user before update on db_test_user for each row when new.create_time <> old.create_time begin select raise(abort, 'immutable column'); end; create trigger default_create_time_column_db_test_user before insert on db_test_user for each row begin update db_test_user set create_time = datetime('now','localtime') where rowid = new.rowid; end; create trigger update_version_column_db_test_user after update on db_test_user for each row when new.public_id <> old.public_id or new.name <> old.name or new.phone_number <> old.phone_number or new.email <> old.email begin update db_test_user set version = old.version + 1 where rowid = new.rowid; end; create table if not exists db_test_car ( public_id text constraint db_test_car_pkey primary key, create_time timestamp not null default current_timestamp, update_time timestamp not null default current_timestamp, name text unique, model text, mpg smallint, version int default 1 ); create trigger update_time_column_db_test_car before update on db_test_car for each row when new.public_id <> old.public_id or new.name <> old.name or new.model <> old.model or new.mpg <> old.mpg or new.version <> old.version begin update db_test_car set update_time = datetime('now','localtime') where rowid == new.rowid; end; create trigger default_create_time_column_db_test_car before insert on db_test_car for each row begin update db_test_car set create_time = datetime('now','localtime') where rowid = new.rowid; end; create trigger update_version_column_db_test_car after update on db_test_car for each row when new.public_id <> old.public_id or new.name <> old.name or new.model <> old.model or new.mpg <> old.mpg begin update db_test_car set version = old.version + 1 where rowid = new.rowid; end; create table if not exists db_test_rental ( user_id text not null references db_test_user(public_id), car_id text not null references db_test_car(public_id), create_time timestamp not null default current_timestamp, update_time timestamp not null default current_timestamp, name text unique, version int default 1, constraint db_test_rental_pkey primary key(user_id, car_id) ); create trigger update_time_column_db_test_rental before update on db_test_rental for each row when new.user_id <> old.user_id or new.car_id <> old.car_id or new.name <> old.name or new.version <> old.version begin update db_test_rental set update_time = datetime('now','localtime') where rowid == new.rowid; end; create trigger immutable_columns_db_test_rental before update on db_test_rental for each row when new.create_time <> old.create_time begin select raise(abort, 'immutable column'); end; create trigger default_create_time_column_db_test_rental before insert on db_test_rental for each row begin update db_test_rental set create_time = datetime('now','localtime') where rowid = new.rowid; end; create trigger update_version_column_db_test_rental after update on db_test_rental for each row when new.user_id <> old.user_id or new.car_id <> old.car_id or new.name <> old.name or new.version <> old.version begin update db_test_rental set version = old.version + 1 where rowid = new.rowid; end; create table if not exists db_test_scooter ( private_id text constraint db_test_scooter_pkey primary key, create_time timestamp not null default current_timestamp, update_time timestamp not null default current_timestamp, name text unique, model text, mpg smallint, version int default 1 ); create trigger update_time_column_db_test_scooter before update on db_test_scooter for each row when new.private_id <> old.private_id or new.name <> old.name or new.model <> old.model or new.mpg <> old.mpg or new.version <> old.version begin update db_test_scooter set update_time = datetime('now','localtime') where rowid == new.rowid; end; create trigger default_create_time_column_db_test_scooter before insert on db_test_scooter for each row begin update db_test_scooter set create_time = datetime('now','localtime') where rowid = new.rowid; end; create trigger update_version_column_db_test_scooter after update on db_test_scooter for each row when new.private_id <> old.private_id or new.name <> old.name or new.model <> old.model or new.mpg <> old.mpg begin update db_test_scooter set version = old.version + 1 where rowid = new.rowid; end; commit; ` testQueryCreateTablesPostgres = ` begin; create domain wt_public_id as text check( length(trim(value)) > 10 ); comment on domain wt_public_id is 'Random ID generated with github.com/hashicorp/go-secure-stdlib/base62'; create domain wt_private_id as text not null check( length(trim(value)) > 10 ); comment on domain wt_private_id is 'Random ID generated with github.com/hashicorp/go-secure-stdlib/base62'; drop domain if exists wt_timestamp; create domain wt_timestamp as timestamp with time zone default current_timestamp; comment on domain wt_timestamp is 'Standard timestamp for all create_time and update_time columns'; create or replace function update_time_column() returns trigger as $$ begin if row(new.*) is distinct from row(old.*) then new.update_time = now(); return new; else return old; end if; end; $$ language plpgsql; comment on function update_time_column() is 'function used in before update triggers to properly set update_time columns'; create or replace function default_create_time() returns trigger as $$ begin if new.create_time is distinct from now() then raise warning 'create_time cannot be set to %', new.create_time; new.create_time = now(); end if; return new; end; $$ language plpgsql; comment on function default_create_time() is 'function used in before insert triggers to set create_time column to now'; create domain wt_version as bigint default 1 not null check( value > 0 ); comment on domain wt_version is 'standard column for row version'; -- update_version_column() will increment the version column whenever row data -- is updated and should only be used in an update after trigger. This function -- will overwrite any explicit updates to the version column. The function -- accepts an optional parameter of 'private_id' for the tables primary key. create or replace function update_version_column() returns trigger as $$ begin if pg_trigger_depth() = 1 then if row(new.*) is distinct from row(old.*) then if tg_nargs = 0 then execute format('update %I set version = $1 where public_id = $2', tg_relid::regclass) using old.version+1, new.public_id; new.version = old.version + 1; return new; end if; if tg_argv[0] = 'private_id' then execute format('update %I set version = $1 where private_id = $2', tg_relid::regclass) using old.version+1, new.private_id; new.version = old.version + 1; return new; end if; end if; end if; return new; end; $$ language plpgsql; comment on function update_version_column() is 'function used in after update triggers to properly set version columns'; -- immutable_columns() will make the column names immutable which are passed as -- parameters when the trigger is created. It raises error code 23601 which is a -- class 23 integrity constraint violation: immutable column create or replace function immutable_columns() returns trigger as $$ declare col_name text; new_value text; old_value text; begin foreach col_name in array tg_argv loop execute format('SELECT $1.%I', col_name) into new_value using new; execute format('SELECT $1.%I', col_name) into old_value using old; if new_value is distinct from old_value then raise exception 'immutable column: %.%', tg_table_name, col_name using errcode = '23601', schema = tg_table_schema, table = tg_table_name, column = col_name; end if; end loop; return new; end; $$ language plpgsql; comment on function immutable_columns() is 'function used in before update triggers to make columns immutable'; -- ######################################################################################## create table if not exists db_test_user ( public_id wt_public_id constraint db_test_user_pkey primary key, create_time wt_timestamp, update_time wt_timestamp, name text unique, phone_number text, email text, version wt_version ); create trigger update_time_column before update on db_test_user for each row execute procedure update_time_column(); -- define the immutable fields for db_test_user create trigger immutable_columns before update on db_test_user for each row execute procedure immutable_columns('create_time'); create trigger default_create_time_column before insert on db_test_user for each row execute procedure default_create_time(); create trigger update_version_column after update on db_test_user for each row execute procedure update_version_column(); create table if not exists db_test_car ( public_id wt_public_id constraint db_test_car_pkey primary key, create_time wt_timestamp, update_time wt_timestamp, name text unique, model text, mpg smallint ); create trigger update_time_column before update on db_test_car for each row execute procedure update_time_column(); -- define the immutable fields for db_test_car create trigger immutable_columns before update on db_test_car for each row execute procedure immutable_columns('create_time'); create trigger default_create_time_column before insert on db_test_car for each row execute procedure default_create_time(); create table if not exists db_test_rental ( user_id wt_public_id not null references db_test_user(public_id), car_id wt_public_id not null references db_test_car(public_id), create_time wt_timestamp, update_time wt_timestamp, name text unique, version wt_version, constraint db_test_rental_pkey primary key(user_id, car_id) ); create trigger update_time_column before update on db_test_rental for each row execute procedure update_time_column(); -- define the immutable fields for db_test_rental create trigger immutable_columns before update on db_test_rental for each row execute procedure immutable_columns('create_time'); create trigger default_create_time_column before insert on db_test_rental for each row execute procedure default_create_time(); -- update_version_column() will increment the version column whenever row data -- is updated and should only be used in an update after trigger. This function -- will overwrite any explicit updates to the version column. The function -- accepts an optional parameter of 'private_id' for the tables primary key. create or replace function update_rental_version_column() returns trigger as $$ begin if pg_trigger_depth() = 1 then if row(new.*) is distinct from row(old.*) then if tg_nargs = 0 then execute format('update %I set version = $1 where user_id = $2 and car_id = $3', tg_relid::regclass) using old.version+1, new.user_id, new.car_id; new.version = old.version + 1; return new; end if; end if; end if; return new; end; $$ language plpgsql; create trigger update_version_column after update on db_test_rental for each row execute procedure update_rental_version_column(); create table if not exists db_test_scooter ( private_id wt_private_id constraint db_test_scooter_pkey primary key, create_time wt_timestamp, update_time wt_timestamp, name text unique, model text, mpg smallint ); create trigger update_time_column before update on db_test_scooter for each row execute procedure update_time_column(); -- define the immutable fields for db_test_scooter create trigger immutable_columns before update on db_test_scooter for each row execute procedure immutable_columns('create_time'); create trigger default_create_time_column before insert on db_test_scooter for each row execute procedure default_create_time(); commit; ` testQueryDropTablesSqlite = ` begin; drop table if exists db_test_user; drop table if exists db_test_car; drop table if exists db_test_rental; drop table if exists db_test_scooter; commit; ` testQueryDropTablesPostgres = ` begin; drop table if exists db_test_user cascade; drop table if exists db_test_car cascade; drop table if exists db_test_rental cascade; drop table if exists db_test_scooter cascade; drop domain if exists wt_public_id; drop domain if exists wt_private_id; drop domain if exists wt_timestamp; drop domain if exists wt_version; commit; ` )
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package dbw import ( "context" "fmt" ) // Begin will start a transaction func (rw *RW) Begin(ctx context.Context) (*RW, error) { const op = "dbw.Begin" newTx := rw.underlying.wrapped.WithContext(ctx) newTx = newTx.Begin() if newTx.Error != nil { return nil, fmt.Errorf("%s: %w", op, newTx.Error) } return New( &DB{wrapped: newTx}, ), nil } // Rollback will rollback the current transaction func (rw *RW) Rollback(ctx context.Context) error { const op = "dbw.Rollback" db := rw.underlying.wrapped.WithContext(ctx) if err := db.Rollback().Error; err != nil { return fmt.Errorf("%s: %w", op, err) } return nil } // Commit will commit a transaction func (rw *RW) Commit(ctx context.Context) error { const op = "dbw.Commit" db := rw.underlying.wrapped.WithContext(ctx) if err := db.Commit().Error; err != nil { return fmt.Errorf("%s: %w", op, err) } return nil }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package dbw import ( "context" "fmt" "reflect" "sync/atomic" "gorm.io/gorm" ) var nonUpdateFields atomic.Value // InitNonUpdatableFields sets the fields which are not updatable using // via RW.Update(...) func InitNonUpdatableFields(fields []string) { m := make(map[string]struct{}, len(fields)) for _, f := range fields { m[f] = struct{}{} } nonUpdateFields.Store(m) } // NonUpdatableFields returns the current set of fields which are not updatable using // via RW.Update(...) func NonUpdatableFields() []string { m := nonUpdateFields.Load() if m == nil { return []string{} } fields := make([]string, 0, len(m.(map[string]struct{}))) for f := range m.(map[string]struct{}) { fields = append(fields, f) } return fields } // Update a resource in the db, a fieldMask is required and provides // field_mask.proto paths for fields that should be updated. The i interface // parameter is the type the caller wants to update in the db and its fields are // set to the update values. setToNullPaths is optional and provides // field_mask.proto paths for the fields that should be set to null. // fieldMaskPaths and setToNullPaths must not intersect. The caller is // responsible for the transaction life cycle of the writer and if an error is // returned the caller must decide what to do with the transaction, which almost // always should be to rollback. Update returns the number of rows updated. // // Supported options: WithBeforeWrite, WithAfterWrite, WithWhere, WithDebug, // WithTable and WithVersion. If WithVersion is used, then the update will // include the version number in the update where clause, which basically makes // the update use optimistic locking and the update will only succeed if the // existing rows version matches the WithVersion option. Zero is not a valid // value for the WithVersion option and will return an error. WithWhere allows // specifying an additional constraint on the operation in addition to the PKs. // WithDebug will turn on debugging for the update call. func (rw *RW) Update(ctx context.Context, i interface{}, fieldMaskPaths []string, setToNullPaths []string, opt ...Option) (int, error) { const op = "dbw.Update" if rw.underlying == nil { return noRowsAffected, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) } if isNil(i) { return noRowsAffected, fmt.Errorf("%s: missing interface: %w", op, ErrInvalidParameter) } if err := raiseErrorOnHooks(i); err != nil { return noRowsAffected, fmt.Errorf("%s: %w", op, err) } if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 { return noRowsAffected, fmt.Errorf("%s: both fieldMaskPaths and setToNullPaths are missing: %w", op, ErrInvalidParameter) } opts := GetOpts(opt...) // we need to filter out some non-updatable fields (like: CreateTime, etc) fieldMaskPaths = filterPaths(fieldMaskPaths) setToNullPaths = filterPaths(setToNullPaths) if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 { return noRowsAffected, fmt.Errorf("%s: after filtering non-updated fields, there are no fields left in fieldMaskPaths or setToNullPaths: %w", op, ErrInvalidParameter) } updateFields, err := UpdateFields(i, fieldMaskPaths, setToNullPaths) if err != nil { return noRowsAffected, fmt.Errorf("%s: getting update fields failed: %w", op, err) } if len(updateFields) == 0 { return noRowsAffected, fmt.Errorf("%s: no fields matched using fieldMaskPaths %s: %w", op, fieldMaskPaths, ErrInvalidParameter) } names, isZero, err := rw.primaryFieldsAreZero(ctx, i) if err != nil { return noRowsAffected, fmt.Errorf("%s: %w", op, err) } if isZero { return noRowsAffected, fmt.Errorf("%s: primary key is not set for: %s: %w", op, names, ErrInvalidParameter) } mDb := rw.underlying.wrapped.Model(i) err = mDb.Statement.Parse(i) if err != nil || mDb.Statement.Schema == nil { return noRowsAffected, fmt.Errorf("%s: internal error: unable to parse stmt: %w", op, err) } reflectValue := reflect.Indirect(reflect.ValueOf(i)) for _, pf := range mDb.Statement.Schema.PrimaryFields { if _, isZero := pf.ValueOf(ctx, reflectValue); isZero { return noRowsAffected, fmt.Errorf("%s: primary key %s is not set: %w", op, pf.Name, ErrInvalidParameter) } if contains(fieldMaskPaths, pf.Name) { return noRowsAffected, fmt.Errorf("%s: not allowed on primary key field %s: %w", op, pf.Name, ErrInvalidFieldMask) } } if !opts.WithSkipVetForWrite { if vetter, ok := i.(VetForWriter); ok { if err := vetter.VetForWrite(ctx, rw, UpdateOp, WithFieldMaskPaths(fieldMaskPaths), WithNullPaths(setToNullPaths)); err != nil { return noRowsAffected, fmt.Errorf("%s: %w", op, err) } } } if opts.WithBeforeWrite != nil { if err := opts.WithBeforeWrite(i); err != nil { return noRowsAffected, fmt.Errorf("%s: error before write: %w", op, err) } } underlying := rw.underlying.wrapped.Model(i) if opts.WithDebug { underlying = underlying.Debug() } if opts.WithTable != "" { underlying = underlying.Table(opts.WithTable) } switch { case opts.WithVersion != nil || opts.WithWhereClause != "": where, args, err := rw.whereClausesFromOpts(ctx, i, opts) if err != nil { return noRowsAffected, fmt.Errorf("%s: %w", op, err) } underlying = underlying.Where(where, args...).Updates(updateFields) default: underlying = underlying.Updates(updateFields) } if underlying.Error != nil { if underlying.Error == gorm.ErrRecordNotFound { return noRowsAffected, fmt.Errorf("%s: %w", op, gorm.ErrRecordNotFound) } return noRowsAffected, fmt.Errorf("%s: %w", op, underlying.Error) } rowsUpdated := int(underlying.RowsAffected) if rowsUpdated > 0 && (opts.WithAfterWrite != nil) { if err := opts.WithAfterWrite(i, rowsUpdated); err != nil { return rowsUpdated, fmt.Errorf("%s: error after write: %w", op, err) } } // we need to force a lookupAfterWrite so the resource returned is correctly initialized // from the db opt = append(opt, WithLookup(true)) if err := rw.lookupAfterWrite(ctx, i, opt...); err != nil { return noRowsAffected, fmt.Errorf("%s: %w", op, err) } return rowsUpdated, nil } // filterPaths will filter out non-updatable fields func filterPaths(paths []string) []string { if len(paths) == 0 { return nil } nonUpdatable := NonUpdatableFields() if len(nonUpdatable) == 0 { return paths } var filtered []string for _, p := range paths { switch { case contains(nonUpdatable, p): continue default: filtered = append(filtered, p) } } return filtered }