// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbw
import (
"math"
"math/rand"
"time"
)
// Backoff defines an interface for providing a back off for retrying
// transactions. See DoTx(...)
type Backoff interface {
	// Duration returns the amount of time to wait before retry attempt
	// number attemptNumber.
	Duration(attemptNumber uint) time.Duration
}
// ConstBackoff defines a constant backoff for retrying transactions. See
// DoTx(...)
type ConstBackoff struct {
	// DurationMs is the number of milliseconds to wait between attempts.
	DurationMs time.Duration
}

// Duration returns the same fixed delay for every retry attempt.
func (b ConstBackoff) Duration(attempt uint) time.Duration {
	return time.Duration(b.DurationMs) * time.Millisecond
}
// ExpBackoff defines an exponential backoff for retrying transactions. See DoTx(...)
type ExpBackoff struct {
	// testRand, when > 0, replaces the random jitter factor (test hook only).
	testRand float64
}

// Duration returns an exponentially growing, jittered delay for the given
// retry attempt: roughly 2^attempt * 5ms scaled by a random factor in [0.5, 1.5).
func (b ExpBackoff) Duration(attempt uint) time.Duration {
	jitter := b.testRand
	if jitter <= 0 {
		jitter = rand.Float64()
	}
	ms := math.Exp2(float64(attempt)) * 5 * (jitter + 0.5)
	return time.Duration(ms) * time.Millisecond
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbw
import (
"sort"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// ColumnValue defines a column and it's assigned value for a database
// operation. See: SetColumnValues(...)
type ColumnValue struct {
	// Column name
	Column string
	// Value is the column's value
	Value interface{}
}

// Column represents a table Column
type Column struct {
	// Name of the column
	Name string
	// Table name of the column
	Table string
}

// toAssignment converts the receiver into a gorm assignment that sets the
// named column to the value of this column (Table.Name). Used when building
// "on conflict do update" sets.
func (c *Column) toAssignment(column string) clause.Assignment {
	return clause.Assignment{
		Column: clause.Column{Name: column},
		Value:  clause.Column{Table: c.Table, Name: c.Name},
	}
}

// rawAssignment builds a gorm assignment that sets the named column to the
// given literal value.
func rawAssignment(column string, value interface{}) clause.Assignment {
	return clause.Assignment{
		Column: clause.Column{Name: column},
		Value:  value,
	}
}
// ExprValue encapsulates an expression value for a column assignment. See
// Expr(...) to create these values.
type ExprValue struct {
	// Sql is the raw sql expression (may contain ? placeholders)
	Sql string
	// Vars are the values bound to the expression's placeholders
	Vars []interface{}
}

// toAssignment converts the receiver into a gorm assignment that sets the
// named column to the result of the sql expression.
func (ev *ExprValue) toAssignment(column string) clause.Assignment {
	return clause.Assignment{
		Column: clause.Column{Name: column},
		Value:  gorm.Expr(ev.Sql, ev.Vars...),
	}
}
// Expr creates an expression value (ExprValue) which can be used when setting
// column values for database operations.
//
// Set name column to null example:
//
//	SetColumnValues(map[string]interface{}{"name": Expr("NULL")})
//
// Set exp_time column to N seconds from now:
//
//	SetColumnValues(map[string]interface{}{"exp_time": Expr("wt_add_seconds_to_now(?)", 10)})
func Expr(expr string, args ...interface{}) ExprValue {
	ev := ExprValue{
		Sql:  expr,
		Vars: args,
	}
	return ev
}
// SetColumnValues defines a map from column names to values for database
// operations. The returned assignments are ordered by column name so the
// result is deterministic.
func SetColumnValues(columnValues map[string]interface{}) []ColumnValue {
	names := make([]string, 0, len(columnValues))
	for name := range columnValues {
		names = append(names, name)
	}
	// sort for a stable, deterministic assignment order
	sort.Strings(names)
	result := make([]ColumnValue, 0, len(names))
	for _, name := range names {
		result = append(result, ColumnValue{Column: name, Value: columnValues[name]})
	}
	return result
}
// SetColumns defines a list of column (names) to update using the set of
// proposed insert columns during an on conflict update.
func SetColumns(names []string) []ColumnValue {
	assignments := make([]ColumnValue, 0, len(names))
	for _, name := range names {
		// "excluded" is the sql name for the proposed (conflicting) insert row
		cv := ColumnValue{
			Column: name,
			Value:  Column{Name: name, Table: "excluded"},
		}
		assignments = append(assignments, cv)
	}
	return assignments
}
// OnConflict specifies how to handle alternative actions to take when an insert
// results in a unique constraint or exclusion constraint error.
type OnConflict struct {
	// Target specifies what conflict you want to define a policy for. This can
	// be any one of these:
	// Columns: the name of a specific column or columns
	// Constraint: the name of a unique constraint
	Target interface{}

	// Action specifies the action to take on conflict. This can be any one of
	// these:
	// DoNothing: leaves the conflicting record as-is
	// UpdateAll: updates all the columns of the conflicting record using the resource's data
	// []ColumnValue: update a set of columns of the conflicting record using the set of assignments
	Action interface{}
}

// Constraint defines database constraint name
type Constraint string

// Columns defines a set of column names
type Columns []string

// DoNothing defines an "on conflict" action of doing nothing
type DoNothing bool

// UpdateAll defines an "on conflict" action of updating all columns using the
// proposed insert column values
type UpdateAll bool
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbw
import (
"fmt"
"reflect"
"strings"
"gorm.io/gorm"
)
// UpdateFields will create a map[string]interface of the update values to be
// sent to the db. The map keys will be the field names for the fields to be
// updated. The caller provided fieldMaskPaths and setToNullPaths must not
// intersect. fieldMaskPaths and setToNullPaths cannot both be zero len.
//
// Fields named in fieldMaskPaths are mapped to their current values from i;
// fields named in setToNullPaths are mapped to a SQL NULL expression. Paths
// are matched case-insensitively against i's fields, including fields of
// embedded structs (one level down). An error wrapping ErrInvalidParameter is
// returned if any supplied path cannot be found on i.
func UpdateFields(i interface{}, fieldMaskPaths []string, setToNullPaths []string) (map[string]interface{}, error) {
	const op = "dbw.UpdateFields"
	if i == nil {
		return nil, fmt.Errorf("%s: interface is missing: %w", op, ErrInvalidParameter)
	}
	if fieldMaskPaths == nil {
		fieldMaskPaths = []string{}
	}
	if setToNullPaths == nil {
		setToNullPaths = []string{}
	}
	if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 {
		return nil, fmt.Errorf("%s: both fieldMaskPaths and setToNullPaths are zero len: %w", op, ErrInvalidParameter)
	}
	inter, maskPaths, nullPaths, err := Intersection(fieldMaskPaths, setToNullPaths)
	if err != nil {
		// wrap the underlying error instead of discarding it (it already
		// wraps ErrInvalidParameter)
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	if len(inter) != 0 {
		return nil, fmt.Errorf("%s: fieldMaskPaths and setToNullPaths cannot intersect: %w", op, ErrInvalidParameter)
	}
	updateFields := map[string]interface{}{} // case sensitive update fields to values
	found := map[string]struct{}{}           // we need something to keep track of found fields (case insensitive)
	val := reflect.Indirect(reflect.ValueOf(i))
	structTyp := val.Type()
	for i := 0; i < structTyp.NumField(); i++ {
		if f, ok := maskPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok {
			updateFields[f] = val.Field(i).Interface()
			found[strings.ToUpper(f)] = struct{}{}
			continue
		}
		if f, ok := nullPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok {
			updateFields[f] = gorm.Expr("NULL")
			found[strings.ToUpper(f)] = struct{}{}
			continue
		}
		kind := structTyp.Field(i).Type.Kind()
		if kind == reflect.Struct || kind == reflect.Ptr {
			embType := structTyp.Field(i).Type
			// check if the embedded field is exported via CanInterface()
			if val.Field(i).CanInterface() {
				embVal := reflect.Indirect(reflect.ValueOf(val.Field(i).Interface()))
				// if it's a ptr to a struct, then we need a few more bits before proceeding.
				if kind == reflect.Ptr {
					embVal = val.Field(i).Elem()
					if !embVal.IsValid() {
						// nil pointer: nothing to inspect
						continue
					}
					embType = embVal.Type()
					if embType.Kind() != reflect.Struct {
						continue
					}
				}
				for embFieldNum := 0; embFieldNum < embType.NumField(); embFieldNum++ {
					if f, ok := maskPaths[strings.ToUpper(embType.Field(embFieldNum).Name)]; ok {
						updateFields[f] = embVal.Field(embFieldNum).Interface()
						found[strings.ToUpper(f)] = struct{}{}
					}
					if f, ok := nullPaths[strings.ToUpper(embType.Field(embFieldNum).Name)]; ok {
						updateFields[f] = gorm.Expr("NULL")
						found[strings.ToUpper(f)] = struct{}{}
					}
				}
				continue
			}
		}
	}
	// every caller-supplied path must have matched a field on i
	if missing := findMissingPaths(setToNullPaths, found); len(missing) != 0 {
		return nil, fmt.Errorf("%s: null paths not found in resource: %s: %w", op, missing, ErrInvalidParameter)
	}
	if missing := findMissingPaths(fieldMaskPaths, found); len(missing) != 0 {
		return nil, fmt.Errorf("%s: field mask paths not found in resource: %s: %w", op, missing, ErrInvalidParameter)
	}
	return updateFields, nil
}
// findMissingPaths returns the subset of paths that do not appear in
// foundPaths. The lookup is case-insensitive (foundPaths keys are uppercase).
func findMissingPaths(paths []string, foundPaths map[string]struct{}) []string {
	missing := make([]string, 0, len(paths))
	for _, p := range paths {
		if _, ok := foundPaths[strings.ToUpper(p)]; ok {
			continue
		}
		missing = append(missing, p)
	}
	return missing
}
// Intersection is a case-insensitive search for intersecting values. Returns
// []string of the Intersection with values in lowercase, and map[string]string
// of the original av and bv, with the key set to uppercase and value set to the
// original
func Intersection(av, bv []string) ([]string, map[string]string, map[string]string, error) {
	const op = "dbw.Intersection"
	switch {
	case av == nil:
		return nil, nil, nil, fmt.Errorf("%s: av is missing: %w", op, ErrInvalidParameter)
	case bv == nil:
		return nil, nil, nil, fmt.Errorf("%s: bv is missing: %w", op, ErrInvalidParameter)
	case len(av) == 0 && len(bv) == 0:
		// both empty: nothing to intersect, but not an error
		return []string{}, map[string]string{}, map[string]string{}, nil
	}
	ah := make(map[string]string, len(av))
	for _, v := range av {
		ah[strings.ToUpper(v)] = v
	}
	bh := make(map[string]string, len(bv))
	inter := []string{}
	for _, v := range bv {
		k := strings.ToUpper(v)
		bh[k] = v
		if _, ok := ah[k]; ok {
			// record intersections in lowercase, per the contract above
			inter = append(inter, strings.ToLower(v))
		}
	}
	return inter, ah, bh, nil
}
// BuildUpdatePaths takes a map of field names to field values, field masks,
// fields allowed to be zero value, and returns both a list of field names to
// update and a list of field names that should be set to null.
func BuildUpdatePaths(fieldValues map[string]interface{}, fieldMask []string, allowZeroFields []string) (masks []string, nulls []string) {
	for name, value := range fieldValues {
		// only fields named in the mask are considered at all
		if !contains(fieldMask, name) {
			continue
		}
		// zero values become nulls, unless the field is allowed to be zero
		if isZero(value) && !contains(allowZeroFields, name) {
			nulls = append(nulls, name)
			continue
		}
		masks = append(masks, name)
	}
	return masks, nulls
}
// isZero reports whether i is nil or equal to the zero value of its type.
func isZero(i interface{}) bool {
	if i == nil {
		return true
	}
	zero := reflect.Zero(reflect.TypeOf(i)).Interface()
	return reflect.DeepEqual(i, zero)
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbw
import (
"context"
"fmt"
"reflect"
"strings"
"sync/atomic"
"gorm.io/gorm/clause"
)
// OpType defines a set of database operation types
type OpType int

const (
	// UnknownOp is an unknown operation
	UnknownOp OpType = 0
	// CreateOp is a create operation
	CreateOp OpType = 1
	// UpdateOp is an update operation
	UpdateOp OpType = 2
	// DeleteOp is a delete operation
	DeleteOp OpType = 3

	// DefaultBatchSize is the default batch size for bulk operations like
	// CreateItems. This value is used if the caller does not specify a size
	// using the WithBatchSize(...) option. Note: some databases have a limit
	// on the number of query parameters (postgres is currently 64k and sqlite
	// is 32k) and/or size of a SQL statement (sqlite is currently 1bn bytes),
	// so this value should be set to a value that is less than the limits for
	// your target db.
	// See:
	// - https://www.postgresql.org/docs/current/limits.html
	// - https://www.sqlite.org/limits.html
	DefaultBatchSize = 1000
)
// VetForWriter provides an interface that Create and Update can use to vet the
// resource before writing it to the db. For optType == UpdateOp,
// options WithFieldMaskPath and WithNullPaths are supported. For optType ==
// CreateOp, no options are supported
type VetForWriter interface {
	// VetForWrite validates the resource before a write; returning an error
	// aborts the pending Create/Update.
	VetForWrite(ctx context.Context, r Reader, opType OpType, opt ...Option) error
}
// nonCreateFields holds a map[string]struct{} of field names that must not
// be set by callers of RW.Create(...); access is via atomic.Value so reads
// and writes are safe across goroutines.
var nonCreateFields atomic.Value

// InitNonCreatableFields sets the fields which are not settable via
// RW.Create(...)
func InitNonCreatableFields(fields []string) {
	set := make(map[string]struct{}, len(fields))
	for _, name := range fields {
		set[name] = struct{}{}
	}
	nonCreateFields.Store(set)
}

// NonCreatableFields returns the current set of fields which are not settable
// via RW.Create(...). It returns an empty slice if InitNonCreatableFields has
// never been called.
func NonCreatableFields() []string {
	stored := nonCreateFields.Load()
	if stored == nil {
		return []string{}
	}
	set := stored.(map[string]struct{})
	fields := make([]string, 0, len(set))
	for name := range set {
		fields = append(fields, name)
	}
	return fields
}
// Create a resource in the db with options: WithDebug, WithLookup,
// WithReturnRowsAffected, OnConflict, WithBeforeWrite, WithAfterWrite,
// WithVersion, WithTable, and WithWhere.
//
// OnConflict specifies alternative actions to take when an insert results in a
// unique constraint or exclusion constraint error. If WithVersion is used with
// OnConflict, then the update for on conflict will include the version number,
// which basically makes the update use optimistic locking and the update will
// only succeed if the existing rows version matches the WithVersion option.
// Zero is not a valid value for the WithVersion option and will return an
// error. WithWhere allows specifying an additional constraint on the on
// conflict operation in addition to the on conflict target policy (columns or
// constraint).
func (rw *RW) Create(ctx context.Context, i interface{}, opt ...Option) error {
	const op = "dbw.Create"
	if rw.underlying == nil {
		return fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter)
	}
	if isNil(i) {
		return fmt.Errorf("%s: missing interface: %w", op, ErrInvalidParameter)
	}
	if err := raiseErrorOnHooks(i); err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	opts := GetOpts(opt...)
	// these fields should be nil, since they are not writeable and we want the
	// db to manage them
	setFieldsToNil(i, NonCreatableFields())
	if !opts.WithSkipVetForWrite {
		// give the resource a chance to validate itself before writing
		if vetter, ok := i.(VetForWriter); ok {
			if err := vetter.VetForWrite(ctx, rw, CreateOp); err != nil {
				return fmt.Errorf("%s: %w", op, err)
			}
		}
	}
	db := rw.underlying.wrapped.WithContext(ctx)
	if opts.WithOnConflict != nil {
		c := clause.OnConflict{}
		// the conflict target is either a named constraint or a set of columns
		switch opts.WithOnConflict.Target.(type) {
		case Constraint:
			c.OnConstraint = string(opts.WithOnConflict.Target.(Constraint))
		case Columns:
			columns := make([]clause.Column, 0, len(opts.WithOnConflict.Target.(Columns)))
			for _, name := range opts.WithOnConflict.Target.(Columns) {
				columns = append(columns, clause.Column{Name: name})
			}
			c.Columns = columns
		default:
			return fmt.Errorf("%s: invalid conflict target %v: %w", op, reflect.TypeOf(opts.WithOnConflict.Target), ErrInvalidParameter)
		}
		// the conflict action is do-nothing, update-all, or an explicit set
		// of column assignments
		switch opts.WithOnConflict.Action.(type) {
		case DoNothing:
			c.DoNothing = true
		case UpdateAll:
			c.UpdateAll = true
		case []ColumnValue:
			updates := opts.WithOnConflict.Action.([]ColumnValue)
			set := make(clause.Set, 0, len(updates))
			for _, s := range updates {
				// make sure it's not one of the std immutable columns
				if contains([]string{"createtime", "publicid"}, strings.ToLower(s.Column)) {
					return fmt.Errorf("%s: cannot do update on conflict for column %s: %w", op, s.Column, ErrInvalidParameter)
				}
				switch sv := s.Value.(type) {
				case Column:
					set = append(set, sv.toAssignment(s.Column))
				case ExprValue:
					set = append(set, sv.toAssignment(s.Column))
				default:
					set = append(set, rawAssignment(s.Column, s.Value))
				}
			}
			c.DoUpdates = set
		default:
			return fmt.Errorf("%s: invalid conflict action %v: %w", op, reflect.TypeOf(opts.WithOnConflict.Action), ErrInvalidParameter)
		}
		if opts.WithVersion != nil || opts.WithWhereClause != "" {
			// constrain the on-conflict update with the version/where clause
			// (optimistic locking when WithVersion is used)
			where, args, err := rw.whereClausesFromOpts(ctx, i, opts)
			if err != nil {
				return fmt.Errorf("%s: %w", op, err)
			}
			whereConditions := db.Statement.BuildCondition(where, args...)
			c.Where = clause.Where{Exprs: whereConditions}
		}
		db = db.Clauses(c)
	}
	if opts.WithDebug {
		db = db.Debug()
	}
	if opts.WithTable != "" {
		db = db.Table(opts.WithTable)
	}
	if opts.WithBeforeWrite != nil {
		if err := opts.WithBeforeWrite(i); err != nil {
			return fmt.Errorf("%s: error before write: %w", op, err)
		}
	}
	tx := db.Create(i)
	if tx.Error != nil {
		return fmt.Errorf("%s: create failed: %w", op, tx.Error)
	}
	if opts.WithRowsAffected != nil {
		*opts.WithRowsAffected = tx.RowsAffected
	}
	// only fire the after-write callback if something was actually written
	if tx.RowsAffected > 0 && opts.WithAfterWrite != nil {
		if err := opts.WithAfterWrite(i, int(tx.RowsAffected)); err != nil {
			return fmt.Errorf("%s: error after write: %w", op, err)
		}
	}
	if err := rw.lookupAfterWrite(ctx, i, opt...); err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	return nil
}
// CreateItems will create multiple items of the same type. Supported options:
// WithBatchSize, WithDebug, WithBeforeWrite, WithAfterWrite,
// WithReturnRowsAffected, OnConflict, WithVersion, WithTable, and WithWhere.
// WithLookup is not a supported option.
func (rw *RW) CreateItems(ctx context.Context, createItems interface{}, opt ...Option) error {
	const op = "dbw.CreateItems"
	switch {
	case rw.underlying == nil:
		return fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter)
	case isNil(createItems):
		return fmt.Errorf("%s: missing items: %w", op, ErrInvalidParameter)
	}
	// createItems must be a non-empty slice
	valCreateItems := reflect.ValueOf(createItems)
	switch {
	case valCreateItems.Kind() != reflect.Slice:
		return fmt.Errorf("%s: not a slice: %w", op, ErrInvalidParameter)
	case valCreateItems.Len() == 0:
		return fmt.Errorf("%s: missing items: %w", op, ErrInvalidParameter)
	}
	if err := raiseErrorOnHooks(createItems); err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	opts := GetOpts(opt...)
	switch {
	case opts.WithLookup:
		return fmt.Errorf("%s: with lookup not a supported option: %w", op, ErrInvalidParameter)
	}
	var foundType reflect.Type
	for i := 0; i < valCreateItems.Len(); i++ {
		// verify that createItems are all the same type and do some bits on each item
		if i == 0 {
			foundType = reflect.TypeOf(valCreateItems.Index(i).Interface())
		}
		currentType := reflect.TypeOf(valCreateItems.Index(i).Interface())
		if currentType == nil {
			return fmt.Errorf("%s: unable to determine type of item %d: %w", op, i, ErrInvalidParameter)
		}
		if foundType != currentType {
			return fmt.Errorf("%s: create items contains disparate types. item %d is not a %s: %w", op, i, foundType.Name(), ErrInvalidParameter)
		}
		// these fields should be nil, since they are not writeable and we want the
		// db to manage them
		setFieldsToNil(valCreateItems.Index(i).Interface(), NonCreatableFields())
		// vet each item
		if !opts.WithSkipVetForWrite {
			if vetter, ok := valCreateItems.Index(i).Interface().(VetForWriter); ok {
				if err := vetter.VetForWrite(ctx, rw, CreateOp); err != nil {
					return fmt.Errorf("%s: %w", op, err)
				}
			}
		}
	}
	if opts.WithBeforeWrite != nil {
		if err := opts.WithBeforeWrite(createItems); err != nil {
			return fmt.Errorf("%s: error before write: %w", op, err)
		}
	}
	db := rw.underlying.wrapped.WithContext(ctx)
	if opts.WithOnConflict != nil {
		c := clause.OnConflict{}
		// the conflict target is either a named constraint or a set of columns
		switch opts.WithOnConflict.Target.(type) {
		case Constraint:
			c.OnConstraint = string(opts.WithOnConflict.Target.(Constraint))
		case Columns:
			columns := make([]clause.Column, 0, len(opts.WithOnConflict.Target.(Columns)))
			for _, name := range opts.WithOnConflict.Target.(Columns) {
				columns = append(columns, clause.Column{Name: name})
			}
			c.Columns = columns
		default:
			return fmt.Errorf("%s: invalid conflict target %v: %w", op, reflect.TypeOf(opts.WithOnConflict.Target), ErrInvalidParameter)
		}
		// the conflict action is do-nothing, update-all, or an explicit set
		// of column assignments
		switch opts.WithOnConflict.Action.(type) {
		case DoNothing:
			c.DoNothing = true
		case UpdateAll:
			c.UpdateAll = true
		case []ColumnValue:
			updates := opts.WithOnConflict.Action.([]ColumnValue)
			set := make(clause.Set, 0, len(updates))
			for _, s := range updates {
				// make sure it's not one of the std immutable columns
				if contains([]string{"createtime", "publicid"}, strings.ToLower(s.Column)) {
					return fmt.Errorf("%s: cannot do update on conflict for column %s: %w", op, s.Column, ErrInvalidParameter)
				}
				switch sv := s.Value.(type) {
				case Column:
					set = append(set, sv.toAssignment(s.Column))
				case ExprValue:
					set = append(set, sv.toAssignment(s.Column))
				default:
					set = append(set, rawAssignment(s.Column, s.Value))
				}
			}
			c.DoUpdates = set
		default:
			return fmt.Errorf("%s: invalid conflict action %v: %w", op, reflect.TypeOf(opts.WithOnConflict.Action), ErrInvalidParameter)
		}
		if opts.WithVersion != nil || opts.WithWhereClause != "" {
			// this is a bit of a hack, but we need to pass in one of the items
			// to get the where clause since we need to get the gorm Model and
			// Parse the gorm statement to build the where clause
			where, args, err := rw.whereClausesFromOpts(ctx, valCreateItems.Index(0).Interface(), opts)
			if err != nil {
				return fmt.Errorf("%s: %w", op, err)
			}
			whereConditions := db.Statement.BuildCondition(where, args...)
			c.Where = clause.Where{Exprs: whereConditions}
		}
		db = db.Clauses(c)
	}
	if opts.WithDebug {
		db = db.Debug()
	}
	if opts.WithTable != "" {
		db = db.Table(opts.WithTable)
	}
	// insert in batches to stay under db parameter/statement size limits
	// (see DefaultBatchSize)
	tx := db.CreateInBatches(createItems, opts.WithBatchSize)
	if tx.Error != nil {
		return fmt.Errorf("%s: create failed: %w", op, tx.Error)
	}
	if opts.WithRowsAffected != nil {
		*opts.WithRowsAffected = tx.RowsAffected
	}
	// only fire the after-write callback if something was actually written
	if tx.RowsAffected > 0 && opts.WithAfterWrite != nil {
		if err := opts.WithAfterWrite(createItems, int(tx.RowsAffected)); err != nil {
			return fmt.Errorf("%s: error after write: %w", op, err)
		}
	}
	return nil
}
// setFieldsToNil zeroes the named fields on i (and i's immediate children)
// so the db can manage their values. See Clear(...).
func setFieldsToNil(i interface{}, fieldNames []string) {
	// Note: error cases are not handled
	_ = Clear(i, fieldNames, 2)
}
// Clear sets fields in the value pointed to by i to their zero value.
// Clear descends i to depth clearing fields at each level. i must be a
// pointer to a struct. Cycles in i are not detected.
//
// A depth of 2 will change i and i's children. A depth of 1 will change i
// but no children of i. A depth of 0 will return with no changes to i.
func Clear(i interface{}, fields []string, depth int) error {
	const op = "dbw.Clear"
	// nothing to clear or no levels to descend: a no-op by contract
	if len(fields) == 0 || depth == 0 {
		return nil
	}
	// build a set for O(1) field-name lookups
	fm := make(map[string]bool)
	for _, f := range fields {
		fm[f] = true
	}
	v := reflect.ValueOf(i)
	switch v.Kind() {
	case reflect.Ptr:
		// i must be a non-nil pointer to a struct
		if v.IsNil() || v.Elem().Kind() != reflect.Struct {
			return fmt.Errorf("%s: %w", op, ErrInvalidParameter)
		}
		clear(v, fm, depth)
	default:
		return fmt.Errorf("%s: %w", op, ErrInvalidParameter)
	}
	return nil
}
// clear recursively zeroes any settable struct fields named in fields,
// descending at most depth levels of structs. Dereferencing a pointer does
// not consume a level of depth (the depth+1 below restores the level used
// by this call).
func clear(v reflect.Value, fields map[string]bool, depth int) {
	if depth == 0 {
		return
	}
	depth--
	switch v.Kind() {
	case reflect.Ptr:
		// follow the pointer without consuming a level of depth
		clear(v.Elem(), fields, depth+1)
	case reflect.Struct:
		typeOfT := v.Type()
		for i := 0; i < v.NumField(); i++ {
			f := v.Field(i)
			if ok := fields[typeOfT.Field(i).Name]; ok {
				// zero the matching field if it's exported/settable
				if f.IsValid() && f.CanSet() {
					f.Set(reflect.Zero(f.Type()))
				}
				continue
			}
			// not a match: keep descending into nested structs/pointers
			clear(f, fields, depth)
		}
	}
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbw
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/hashicorp/go-hclog"
"github.com/jackc/pgconn"
_ "github.com/jackc/pgx/v5" // required to load postgres drivers
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// DbType defines a database type. It's not an exhaustive list of database
// types which can be used by the dbw package, since you can always use
// OpenWith(...) to connect to KnownDB types.
type DbType int

const (
	// UnknownDB is an unknown db type
	UnknownDB DbType = 0
	// Postgres is a postgre db type
	Postgres DbType = 1
	// Sqlite is a sqlite db type
	Sqlite DbType = 2
)

// String provides a string rep of the DbType. Unrecognized values
// (including UnknownDB) are reported as "unknown" rather than panicking
// with an out-of-range index.
func (db DbType) String() string {
	switch db {
	case Postgres:
		return "postgres"
	case Sqlite:
		return "sqlite"
	default:
		return "unknown"
	}
}

// StringToDbType provides a string to type conversion. If the type is not
// known, then UnknownDB with an error is returned.
func StringToDbType(dialect string) (DbType, error) {
	switch dialect {
	case "postgres":
		return Postgres, nil
	case "sqlite":
		return Sqlite, nil
	default:
		return UnknownDB, fmt.Errorf("%s is an unknown dialect", dialect)
	}
}
// DB is a wrapper around whatever is providing the interface for database
// operations (typically an ORM). DB uses database/sql to maintain connection
// pool.
type DB struct {
	// wrapped is the underlying gorm handle
	wrapped *gorm.DB
}

// DbType will return the DbType and raw name of the connection type
func (db *DB) DbType() (typ DbType, rawName string, e error) {
	// map the gorm dialector's name onto a DbType; unknown dialects map to
	// UnknownDB but the raw name is still returned
	rawName = db.wrapped.Dialector.Name()
	typ, _ = StringToDbType(rawName)
	return typ, rawName, nil
}
// Debug will enable/disable debug info for the connection
func (db *DB) Debug(on bool) {
	// gorm's default level is Error; its Info level corresponds to debug
	// output in this domain
	level := Error
	if on {
		level = Info
	}
	db.LogLevel(level)
}
// LogLevel defines a log level
type LogLevel int

const (
	// Default specifies the default log level
	Default LogLevel = iota
	// Silent is the silent log level
	Silent
	// Error is the error log level
	Error
	// Warn is the warning log level
	Warn
	// Info is the info log level
	Info
)

// LogLevel will set the logging level for the db
func (db *DB) LogLevel(l LogLevel) {
	// dbw LogLevel values map 1:1 onto gorm's logger.LogLevel values
	db.wrapped.Logger = db.wrapped.Logger.LogMode(logger.LogLevel(l))
}
// SqlDB returns the underlying sql.DB Note: this makes it possible to do
// things like set database/sql connection options like SetMaxIdleConns. If
// you're simply setting max/min connections then you should use the
// WithMinOpenConnections and WithMaxOpenConnections options when
// "opening" the database.
//
// Care should be take when deciding to use this for basic database operations
// like Exec, Query, etc since these functions are already provided by dbw.RW
// which provides a layer of encapsulation of the underlying database.
func (db *DB) SqlDB(_ context.Context) (*sql.DB, error) {
	const op = "dbw.(DB).SqlDB"
	if db.wrapped == nil {
		return nil, fmt.Errorf("%s: missing underlying database: %w", op, ErrInternal)
	}
	// delegate to gorm to surface its database/sql handle
	return db.wrapped.DB()
}
// Close the database
//
// Note: Consider if you need to call Close() on the returned DB. Typically the
// answer is no, but there are occasions when it's necessary. See the sql.DB
// docs for more information.
func (db *DB) Close(ctx context.Context) error {
	const op = "dbw.(DB).Close"
	if db.wrapped == nil {
		return fmt.Errorf("%s: missing underlying database: %w", op, ErrInternal)
	}
	// fetch the database/sql handle that gorm wraps and close it
	sqlDB, err := db.wrapped.DB()
	if err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	return sqlDB.Close()
}
// Open a database connection which is long-lived. The options of
// WithLogger, WithLogLevel and WithMaxOpenConnections are supported.
//
// Note: Consider if you need to call Close() on the returned DB. Typically the
// answer is no, but there are occasions when it's necessary. See the sql.DB
// docs for more information.
func Open(dbType DbType, connectionUrl string, opt ...Option) (*DB, error) {
	const op = "dbw.Open"
	if connectionUrl == "" {
		return nil, fmt.Errorf("%s: missing connection url: %w", op, ErrInvalidParameter)
	}
	// pick the gorm dialector that matches the requested db type
	var dialect gorm.Dialector
	switch dbType {
	case Postgres:
		dialect = postgres.New(postgres.Config{
			DSN: connectionUrl,
		},
		)
	case Sqlite:
		dialect = sqlite.Open(connectionUrl)
	default:
		return nil, fmt.Errorf("unable to open %s database type", dbType)
	}
	db, err := openDialector(dialect, opt...)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	if dbType == Sqlite {
		// sqlite does not enforce foreign keys by default; turn them on
		if _, err := New(db).Exec(context.Background(), "PRAGMA foreign_keys=ON", nil); err != nil {
			return nil, fmt.Errorf("%s: unable to enable sqlite foreign keys: %w", op, err)
		}
	}
	return db, nil
}
// Dialector provides a set of functions the database dialect must satisfy to
// be used with OpenWith(...)
// It's a simple wrapper of the gorm.Dialector and provides the ability to open
// any support gorm dialect driver.
type Dialector interface {
	gorm.Dialector
}

// OpenWith will open a database connection using a Dialector which is
// long-lived. The options of WithLogger, WithLogLevel and
// WithMaxOpenConnections are supported.
//
// Note: Consider if you need to call Close() on the returned DB. Typically the
// answer is no, but there are occasions when it's necessary. See the sql.DB
// docs for more information.
func OpenWith(dialector Dialector, opt ...Option) (*DB, error) {
	// Dialector embeds gorm.Dialector, so delegate to the shared open path
	return openDialector(dialector, opt...)
}
// openDialector opens a gorm connection for the given dialect and applies the
// WithLogger, WithLogLevel, WithMinOpenConnections and WithMaxOpenConnections
// options. It is the shared implementation behind Open and OpenWith.
func openDialector(dialect gorm.Dialector, opt ...Option) (*DB, error) {
	db, err := gorm.Open(dialect, &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("unable to open database: %w", err)
	}
	if strings.ToLower(dialect.Name()) == "sqlite" {
		// sqlite does not enforce foreign keys by default; turn them on
		if err := db.Exec("PRAGMA foreign_keys=ON", nil).Error; err != nil {
			return nil, fmt.Errorf("unable to enable sqlite foreign keys: %w", err)
		}
	}
	opts := GetOpts(opt...)
	if opts.WithLogger != nil {
		// install the caller's logger as a gorm session logger
		newLogger := logger.New(
			getGormLogger(opts.WithLogger),
			logger.Config{
				LogLevel: logger.LogLevel(opts.withLogLevel), // Log level
				Colorful: false,                              // Disable color
			},
		)
		db = db.Session(&gorm.Session{Logger: newLogger})
	}
	if opts.WithMaxOpenConnections > 0 {
		// max connections (when limited) must be >= min connections
		if opts.WithMinOpenConnections > 0 && (opts.WithMaxOpenConnections < opts.WithMinOpenConnections) {
			return nil, fmt.Errorf("unable to create db object with dialect %s: %s", dialect, fmt.Sprintf("max_open_connections must be unlimited by setting 0 or at least %d", opts.WithMinOpenConnections))
		}
		underlyingDB, err := db.DB()
		if err != nil {
			return nil, fmt.Errorf("unable retrieve db: %w", err)
		}
		underlyingDB.SetMaxOpenConns(opts.WithMaxOpenConnections)
	}
	return &DB{wrapped: db}, nil
}
// gormLogger adapts an hclog.Logger to the Printf-style writer that gorm's
// logger.New(...) expects.
type gormLogger struct {
	logger hclog.Logger
}

// Printf implements gorm's logger writer interface. It logs (at trace level)
// only messages whose second value is a *pgconn.PgError; everything else is
// dropped.
func (g gormLogger) Printf(_ string, values ...interface{}) {
	if len(values) > 1 {
		switch values[1].(type) {
		case *pgconn.PgError:
			g.logger.Trace("error from database adapter", "location", values[0], "error", values[1])
		}
	}
}

// getGormLogger wraps log in a gormLogger.
func getGormLogger(log hclog.Logger) gormLogger {
	return gormLogger{logger: log}
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbw
import (
"context"
"fmt"
"reflect"
)
// Delete a resource in the db with options: WithWhere, WithDebug, WithTable,
// and WithVersion. WithWhere and WithVersion allows specifying a additional
// constraints on the operation in addition to the PKs. Delete returns the
// number of rows deleted and any errors.
func (rw *RW) Delete(ctx context.Context, i interface{}, opt ...Option) (int, error) {
	const op = "dbw.Delete"
	if rw.underlying == nil {
		return noRowsAffected, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter)
	}
	if isNil(i) {
		return noRowsAffected, fmt.Errorf("%s: missing interface: %w", op, ErrInvalidParameter)
	}
	if err := raiseErrorOnHooks(i); err != nil {
		return noRowsAffected, fmt.Errorf("%s: %w", op, err)
	}
	opts := GetOpts(opt...)
	// parse the gorm statement so we can inspect the schema's primary fields.
	// Note: the original code ignored a non-nil Parse error, which could then
	// nil-deref Schema below; handle both failure modes (matching DeleteItems).
	mDb := rw.underlying.wrapped.Model(i)
	err := mDb.Statement.Parse(i)
	switch {
	case err != nil:
		return noRowsAffected, fmt.Errorf("%s: (internal error) error parsing stmt: %w", op, err)
	case mDb.Statement.Schema == nil:
		return noRowsAffected, fmt.Errorf("%s: (internal error) unable to parse stmt: %w", op, ErrUnknown)
	}
	// deletes are scoped by primary key, so every PK field must be set
	reflectValue := reflect.Indirect(reflect.ValueOf(i))
	for _, pf := range mDb.Statement.Schema.PrimaryFields {
		if _, isZero := pf.ValueOf(ctx, reflectValue); isZero {
			return noRowsAffected, fmt.Errorf("%s: primary key %s is not set: %w", op, pf.Name, ErrInvalidParameter)
		}
	}
	if opts.WithBeforeWrite != nil {
		if err := opts.WithBeforeWrite(i); err != nil {
			return noRowsAffected, fmt.Errorf("%s: error before write: %w", op, err)
		}
	}
	db := rw.underlying.wrapped.WithContext(ctx)
	if opts.WithVersion != nil || opts.WithWhereClause != "" {
		// add the version / where-clause constraints in addition to the PKs
		where, args, err := rw.whereClausesFromOpts(ctx, i, opts)
		if err != nil {
			return noRowsAffected, fmt.Errorf("%s: %w", op, err)
		}
		db = db.Where(where, args...)
	}
	if opts.WithDebug {
		db = db.Debug()
	}
	if opts.WithTable != "" {
		db = db.Table(opts.WithTable)
	}
	db = db.Delete(i)
	if db.Error != nil {
		return noRowsAffected, fmt.Errorf("%s: %w", op, db.Error)
	}
	rowsDeleted := int(db.RowsAffected)
	// only fire the after-write callback if something was actually deleted
	if rowsDeleted > 0 && opts.WithAfterWrite != nil {
		if err := opts.WithAfterWrite(i, rowsDeleted); err != nil {
			return rowsDeleted, fmt.Errorf("%s: error after write: %w", op, err)
		}
	}
	return rowsDeleted, nil
}
// DeleteItems will delete multiple items of the same type. The items must be
// a non-empty slice of resources with their primary key(s) set (unless
// WithWhereClause is used). Options supported: WithWhereClause, WithDebug,
// WithTable
func (rw *RW) DeleteItems(ctx context.Context, deleteItems interface{}, opt ...Option) (int, error) {
	const op = "dbw.DeleteItems"
	switch {
	case rw.underlying == nil:
		return noRowsAffected, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter)
	case isNil(deleteItems):
		return noRowsAffected, fmt.Errorf("%s: no interfaces to delete: %w", op, ErrInvalidParameter)
	}
	valDeleteItems := reflect.ValueOf(deleteItems)
	switch {
	case valDeleteItems.Kind() != reflect.Slice:
		return noRowsAffected, fmt.Errorf("%s: not a slice: %w", op, ErrInvalidParameter)
	case valDeleteItems.Len() == 0:
		return noRowsAffected, fmt.Errorf("%s: missing items: %w", op, ErrInvalidParameter)
	}
	if err := raiseErrorOnHooks(deleteItems); err != nil {
		return noRowsAffected, fmt.Errorf("%s: %w", op, err)
	}
	opts := GetOpts(opt...)
	switch {
	case opts.WithLookup:
		return noRowsAffected, fmt.Errorf("%s: with lookup not a supported option: %w", op, ErrInvalidParameter)
	case opts.WithVersion != nil:
		return noRowsAffected, fmt.Errorf("%s: with version is not a supported option: %w", op, ErrInvalidParameter)
	}
	// we need to dig out the stmt so in just a sec we can make sure the PKs are
	// set for all the items, so we'll just use the first item to do so.
	mDb := rw.underlying.wrapped.Model(valDeleteItems.Index(0).Interface())
	err := mDb.Statement.Parse(valDeleteItems.Index(0).Interface())
	switch {
	case err != nil:
		return noRowsAffected, fmt.Errorf("%s: (internal error) error parsing stmt: %w", op, err)
	case mDb.Statement.Schema == nil: // err is nil here; the prior case returned
		return noRowsAffected, fmt.Errorf("%s: (internal error) unable to parse stmt: %w", op, ErrUnknown)
	}
	// verify that deleteItems are all the same type, among a myriad of
	// other things on the set of items. The first item's type is the
	// reference type for the whole batch.
	foundType := reflect.TypeOf(valDeleteItems.Index(0).Interface())
	for i := 0; i < valDeleteItems.Len(); i++ {
		currentType := reflect.TypeOf(valDeleteItems.Index(i).Interface())
		switch {
		case isNil(valDeleteItems.Index(i).Interface()) || currentType == nil:
			return noRowsAffected, fmt.Errorf("%s: unable to determine type of item %d: %w", op, i, ErrInvalidParameter)
		case foundType != currentType:
			return noRowsAffected, fmt.Errorf("%s: items contain disparate types. item %d is not a %s: %w", op, i, foundType.Name(), ErrInvalidParameter)
		}
		if opts.WithWhereClause == "" {
			// make sure the PK is set for the current item
			reflectValue := reflect.Indirect(reflect.ValueOf(valDeleteItems.Index(i).Interface()))
			for _, pf := range mDb.Statement.Schema.PrimaryFields {
				if _, isZero := pf.ValueOf(ctx, reflectValue); isZero {
					return noRowsAffected, fmt.Errorf("%s: primary key %s is not set: %w", op, pf.Name, ErrInvalidParameter)
				}
			}
		}
	}
	if opts.WithBeforeWrite != nil {
		if err := opts.WithBeforeWrite(deleteItems); err != nil {
			return noRowsAffected, fmt.Errorf("%s: error before write: %w", op, err)
		}
	}
	db := rw.underlying.wrapped.WithContext(ctx)
	if opts.WithDebug {
		db = db.Debug()
	}
	if opts.WithWhereClause != "" {
		where, args, err := rw.whereClausesFromOpts(ctx, valDeleteItems.Index(0).Interface(), opts)
		if err != nil {
			return noRowsAffected, fmt.Errorf("%s: %w", op, err)
		}
		db = db.Where(where, args...)
	}
	switch {
	case opts.WithTable != "":
		db = db.Table(opts.WithTable)
	default:
		// let resources that define their own table name (tableNamer)
		// override gorm's derived table name
		tabler, ok := valDeleteItems.Index(0).Interface().(tableNamer)
		if ok {
			db = db.Table(tabler.TableName())
		}
	}
	db = db.Delete(deleteItems)
	if db.Error != nil {
		return noRowsAffected, fmt.Errorf("%s: %w", op, db.Error)
	}
	rowsDeleted := int(db.RowsAffected)
	if rowsDeleted > 0 && opts.WithAfterWrite != nil {
		// rowsDeleted is already an int; the original applied a redundant
		// int(...) conversion here
		if err := opts.WithAfterWrite(deleteItems, rowsDeleted); err != nil {
			return rowsDeleted, fmt.Errorf("%s: error after write: %w", op, err)
		}
	}
	return rowsDeleted, nil
}
// tableNamer is implemented by resources that define their own table name,
// letting DeleteItems target the correct table when WithTable isn't given.
type tableNamer interface {
	TableName() string
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbw
import (
"context"
"fmt"
"time"
)
// DoTx will wrap the Handler func passed within a transaction with retries.
// You should ensure that any objects written to the db in your TxHandler are
// retryable, which means that the object may be sent to the db several times
// (retried), so things like the primary key may need to be reset before retry.
func (rw *RW) DoTx(ctx context.Context, retryErrorsMatchingFn func(error) bool, retries uint, backOff Backoff, handler TxHandler) (RetryInfo, error) {
	const op = "dbw.DoTx"
	if rw.underlying == nil {
		return RetryInfo{}, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter)
	}
	if backOff == nil {
		return RetryInfo{}, fmt.Errorf("%s: missing backoff: %w", op, ErrInvalidParameter)
	}
	if handler == nil {
		return RetryInfo{}, fmt.Errorf("%s: missing handler: %w", op, ErrInvalidParameter)
	}
	if retryErrorsMatchingFn == nil {
		return RetryInfo{}, fmt.Errorf("%s: missing retry errors matching function: %w", op, ErrInvalidParameter)
	}
	info := RetryInfo{}
	// attempts starts at 1, so "retries" counts retries AFTER the first
	// attempt (retries+1 total tries).
	for attempts := uint(1); ; attempts++ {
		if attempts > retries+1 {
			return info, fmt.Errorf("%s: too many retries: %d of %d: %w", op, attempts-1, retries+1, ErrMaxRetries)
		}
		// step one of this, start a transaction...
		newTx := rw.underlying.wrapped.WithContext(ctx)
		newTx = newTx.Begin()
		newRW := &RW{underlying: &DB{newTx}}
		if err := handler(newRW, newRW); err != nil {
			// NOTE(review): the inner err below shadows the handler's err, so
			// a failed Rollback returns the rollback error instead of the
			// handler error — confirm this is intended.
			if err := newTx.Rollback().Error; err != nil {
				return info, fmt.Errorf("%s: %w", op, err)
			}
			if retry := retryErrorsMatchingFn(err); retry {
				d := backOff.Duration(attempts)
				info.Retries++
				info.Backoff = info.Backoff + d
				// wait out the backoff, but bail early if the ctx is done;
				// note the cancellation path wraps the handler's err, not
				// ctx.Err()
				select {
				case <-ctx.Done():
					return info, fmt.Errorf("%s: cancelled: %w", op, err)
				case <-time.After(d):
					continue
				}
			}
			return info, fmt.Errorf("%s: %w", op, err)
		}
		if err := newTx.Commit().Error; err != nil {
			// NOTE(review): same shadowing as above — a failed Rollback after
			// a failed Commit masks the commit error.
			if err := newTx.Rollback().Error; err != nil {
				return info, fmt.Errorf("%s: %w", op, err)
			}
			return info, fmt.Errorf("%s: %w", op, err)
		}
		return info, nil // it all worked!!!
	}
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbw
import (
"bytes"
"fmt"
"strings"
"github.com/hashicorp/go-secure-stdlib/base62"
"golang.org/x/crypto/blake2b"
)
// NewId creates a new random base62 ID with the provided prefix with an
// underscore delimiter. The WithPrngValues option makes the ID deterministic
// by seeding the random reader with a digest of the provided values.
func NewId(prefix string, opt ...Option) (string, error) {
	const op = "dbw.NewId"
	if prefix == "" {
		return "", fmt.Errorf("%s: missing prefix: %w", op, ErrInvalidParameter)
	}
	var publicId string
	var err error
	opts := GetOpts(opt...)
	if len(opts.WithPrngValues) > 0 {
		// deterministic id: derive the "random" bytes from a blake2b digest
		// of the joined PRNG values
		sum := blake2b.Sum256([]byte(strings.Join(opts.WithPrngValues, "|")))
		reader := bytes.NewReader(sum[:])
		publicId, err = base62.RandomWithReader(10, reader)
	} else {
		publicId, err = base62.Random(10)
	}
	if err != nil {
		// keep ErrInternal as the matchable sentinel, but include the
		// underlying error's detail (the original dropped err entirely)
		return "", fmt.Errorf("%s: unable to generate id: %v: %w", op, err, ErrInternal)
	}
	return fmt.Sprintf("%s_%s", prefix, publicId), nil
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbw
import (
"context"
"fmt"
"gorm.io/gorm"
)
// LookupBy will lookup a resource by it's primary keys, which must be
// unique. If the resource implements either ResourcePublicIder or
// ResourcePrivateIder interface, then they are used as the resource's
// primary key for lookup. Otherwise, the resource tags are used to
// determine it's primary key(s) for lookup. The WithDebug and WithTable
// options are supported.
func (rw *RW) LookupBy(ctx context.Context, resourceWithIder interface{}, opt ...Option) error {
	const op = "dbw.LookupById"
	if rw.underlying == nil {
		return fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter)
	}
	if err := raiseErrorOnHooks(resourceWithIder); err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	if err := validateResourcesInterface(resourceWithIder); err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	where, keys, err := rw.primaryKeysWhere(ctx, resourceWithIder)
	if err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	opts := GetOpts(opt...)
	db := rw.underlying.wrapped.WithContext(ctx)
	if opts.WithTable != "" {
		db = db.Table(opts.WithTable)
	}
	if opts.WithDebug {
		db = db.Debug()
	}
	// the original silently discarded this call's error; surface it so a
	// failed field reset isn't hidden from the caller
	if err := rw.clearDefaultNullResourceFields(ctx, resourceWithIder); err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	if err := db.Where(where, keys...).First(resourceWithIder).Error; err != nil {
		if err == gorm.ErrRecordNotFound {
			// normalize gorm's sentinel into the package's own
			return fmt.Errorf("%s: %w", op, ErrRecordNotFound)
		}
		return fmt.Errorf("%s: %w", op, err)
	}
	return nil
}
// LookupByPublicId will lookup resource by its public_id, which must be unique.
// It simply delegates to LookupBy; the WithTable option is supported.
func (rw *RW) LookupByPublicId(ctx context.Context, resource ResourcePublicIder, opt ...Option) error {
	return rw.LookupBy(ctx, resource, opt...)
}
// lookupAfterWrite re-reads the resource after a write when the WithLookup
// option was set; without that option it's a no-op. Hook errors are raised
// even when no lookup is requested.
func (rw *RW) lookupAfterWrite(ctx context.Context, i interface{}, opt ...Option) error {
	const op = "dbw.lookupAfterWrite"
	opts := GetOpts(opt...)
	if err := raiseErrorOnHooks(i); err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	if !opts.WithLookup {
		return nil
	}
	if err := rw.LookupBy(ctx, i, opt...); err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	return nil
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbw
import (
"github.com/hashicorp/go-hclog"
)
// GetOpts applies each inbound Option to a defaults-initialized Options
// struct and returns the result. Nil options are skipped.
func GetOpts(opt ...Option) Options {
	opts := getDefaultOptions()
	for _, apply := range opt {
		if apply == nil {
			continue
		}
		apply(&opts)
	}
	return opts
}
// Option - how Options are passed as arguments.
type Option func(*Options)

// Options - how Options are represented which have been set via an Option
// function. Use GetOpts(...) to populate this struct with the options that
// have been specified for an operation. All option fields are exported so
// they're available for use by other packages.
type Options struct {
	// WithBeforeWrite provides an option to provide a func to be called before a
	// write operation. The i interface{} passed at runtime will be the resource(s)
	// being written.
	WithBeforeWrite func(i interface{}) error

	// WithAfterWrite provides an option to provide a func to be called after a
	// write operation. The i interface{} passed at runtime will be the resource(s)
	// being written.
	WithAfterWrite func(i interface{}, rowsAffected int) error

	// WithLookup enables a lookup after a write operation.
	WithLookup bool

	// WithLimit provides an option to provide a limit. Intentionally allowing
	// negative integers. If WithLimit < 0, then unlimited results are returned.
	// If WithLimit == 0, then default limits are used for results (see DefaultLimit
	// const).
	WithLimit int

	// WithFieldMaskPaths provides an option to provide field mask paths for update
	// operations.
	WithFieldMaskPaths []string

	// WithNullPaths provides an option to provide null paths for update
	// operations.
	WithNullPaths []string

	// WithVersion provides an option version number for update operations. Using
	// this option requires that your resource has a version column that's
	// incremented for every successful update operation. Version provides an
	// optimistic locking mechanism for write operations.
	WithVersion *uint32

	// WithSkipVetForWrite provides an option to allow skipping vet checks to
	// allow testing lower-level SQL triggers and constraints.
	WithSkipVetForWrite bool

	// WithWhereClause provides an option to provide a where clause for an
	// operation.
	WithWhereClause string

	// WithWhereClauseArgs provides an option to provide a where clause arguments for an
	// operation.
	WithWhereClauseArgs []interface{}

	// WithOrder provides an option to provide an order when searching and looking
	// up.
	WithOrder string

	// WithPrngValues provides an option to provide values to seed a PRNG when
	// generating IDs
	WithPrngValues []string

	// WithLogger specifies an optional hclog to use for db operations. It's only
	// valid for Open(..) and OpenWith(...)
	WithLogger hclog.Logger

	// WithMaxOpenConnections specifies an optional max open connections for the
	// database. A value of zero equals unlimited connections.
	// (the original comment on this field described the min-connections option)
	WithMaxOpenConnections int

	// WithMinOpenConnections specifies an optional min open connections for the
	// database. A value of zero means that there is no min.
	// (the original comment on this field described the max-connections option)
	WithMinOpenConnections int

	// WithDebug indicates that the given operation should invoke debug output
	// mode
	WithDebug bool

	// WithOnConflict specifies an optional on conflict criteria which specify
	// alternative actions to take when an insert results in a unique constraint or
	// exclusion constraint error
	WithOnConflict *OnConflict

	// WithRowsAffected specifies an option for returning the rows affected
	// and typically used with "bulk" write operations.
	WithRowsAffected *int64

	// WithTable specifies an option for setting a table name to use for the
	// operation.
	WithTable string

	// WithBatchSize specifies an option for setting the batch size for bulk
	// operations. If WithBatchSize == 0, then the default batch size is used.
	WithBatchSize int

	// withLogLevel specifies the log level to use for db operations (set via
	// WithLogLevel); unexported, defaults to Error
	withLogLevel LogLevel
}

// getDefaultOptions returns the Options defaults: empty field-mask/null
// paths, the DefaultBatchSize, and Error-level logging.
func getDefaultOptions() Options {
	return Options{
		WithFieldMaskPaths: []string{},
		WithNullPaths:      []string{},
		WithBatchSize:      DefaultBatchSize,
		withLogLevel:       Error,
	}
}
// WithBeforeWrite provides an option to provide a func to be called before a
// write operation. The i interface{} passed at runtime will be the resource(s)
// being written.
func WithBeforeWrite(fn func(i interface{}) error) Option {
	return func(o *Options) {
		o.WithBeforeWrite = fn
	}
}

// WithAfterWrite provides an option to provide a func to be called after a
// write operation. The i interface{} passed at runtime will be the resource(s)
// being written.
func WithAfterWrite(fn func(i interface{}, rowsAffected int) error) Option {
	return func(o *Options) {
		o.WithAfterWrite = fn
	}
}

// WithLookup enables a lookup after a write operation.
func WithLookup(enable bool) Option {
	return func(o *Options) {
		o.WithLookup = enable
	}
}

// WithFieldMaskPaths provides an option to provide field mask paths for update
// operations.
func WithFieldMaskPaths(paths []string) Option {
	return func(o *Options) {
		o.WithFieldMaskPaths = paths
	}
}

// WithNullPaths provides an option to provide null paths for update operations.
func WithNullPaths(paths []string) Option {
	return func(o *Options) {
		o.WithNullPaths = paths
	}
}

// WithLimit provides an option to provide a limit. Intentionally allowing
// negative integers. If WithLimit < 0, then unlimited results are returned.
// If WithLimit == 0, then default limits are used for results (see DefaultLimit
// const).
func WithLimit(limit int) Option {
	return func(o *Options) {
		o.WithLimit = limit
	}
}

// WithVersion provides an option version number for update operations. Using
// this option requires that your resource has a version column that's
// incremented for every successful update operation. Version provides an
// optimistic locking mechanism for write operations.
func WithVersion(version *uint32) Option {
	return func(o *Options) {
		o.WithVersion = version
	}
}

// WithSkipVetForWrite provides an option to allow skipping vet checks to allow
// testing lower-level SQL triggers and constraints
func WithSkipVetForWrite(enable bool) Option {
	return func(o *Options) {
		o.WithSkipVetForWrite = enable
	}
}

// WithWhere provides an option to provide a where clause with arguments for an
// operation. Note: the clause replaces any prior one, while the args are
// appended to any previously-provided args.
func WithWhere(whereClause string, args ...interface{}) Option {
	return func(o *Options) {
		o.WithWhereClause = whereClause
		o.WithWhereClauseArgs = append(o.WithWhereClauseArgs, args...)
	}
}
// WithOrder provides an option to provide an order when searching and looking
// up.
func WithOrder(withOrder string) Option {
	return func(o *Options) {
		o.WithOrder = withOrder
	}
}

// WithPrngValues provides an option to provide values to seed a PRNG when
// generating IDs
func WithPrngValues(withPrngValues []string) Option {
	return func(o *Options) {
		o.WithPrngValues = withPrngValues
	}
}

// WithLogger specifies an optional hclog to use for db operations. It's only
// valid for Open(..) and OpenWith(...)
func WithLogger(l hclog.Logger) Option {
	return func(o *Options) {
		o.WithLogger = l
	}
}

// WithMaxOpenConnections specifies an optional max open connections for the
// database. A value of zero equals unlimited connections
func WithMaxOpenConnections(max int) Option {
	return func(o *Options) {
		o.WithMaxOpenConnections = max
	}
}
// WithMinOpenConnections specifies an optional min open connections for the
// database. A value of zero means that there is no min.
func WithMinOpenConnections(n int) Option {
	// the original parameter was misnamed "max" (copy-paste from
	// WithMaxOpenConnections); parameter names aren't part of the call
	// interface, so this rename is caller-compatible
	return func(o *Options) {
		o.WithMinOpenConnections = n
	}
}
// WithDebug specifies the given operation should invoke debug mode for the
// database output
func WithDebug(with bool) Option {
	return func(o *Options) {
		o.WithDebug = with
	}
}

// WithOnConflict specifies an optional on conflict criteria which specify
// alternative actions to take when an insert results in a unique constraint or
// exclusion constraint error
func WithOnConflict(onConflict *OnConflict) Option {
	return func(o *Options) {
		o.WithOnConflict = onConflict
	}
}

// WithReturnRowsAffected specifies an option for returning the rows affected
// and typically used with "bulk" write operations. (note: the option it sets
// is named WithRowsAffected)
func WithReturnRowsAffected(rowsAffected *int64) Option {
	return func(o *Options) {
		o.WithRowsAffected = rowsAffected
	}
}

// WithTable specifies an option for setting a table name to use for the
// operation.
func WithTable(name string) Option {
	return func(o *Options) {
		o.WithTable = name
	}
}

// WithLogLevel specifies an option for setting the log level
func WithLogLevel(l LogLevel) Option {
	return func(o *Options) {
		o.withLogLevel = l
	}
}

// WithBatchSize specifies an option for setting the batch size for bulk
// operations like CreateItems. If WithBatchSize == 0, the default batch size is
// used (see DefaultBatchSize const).
func WithBatchSize(size int) Option {
	return func(o *Options) {
		o.WithBatchSize = size
	}
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbw
import (
"context"
"database/sql"
"fmt"
)
// Query will run the raw query and return the *sql.Rows results. Query will
// operate within the context of any ongoing transaction for the Reader. The
// caller must close the returned *sql.Rows. Query can/should be used in
// combination with ScanRows. The WithDebug option is supported.
func (rw *RW) Query(ctx context.Context, sql string, values []interface{}, opt ...Option) (*sql.Rows, error) {
	const op = "dbw.Query"
	switch {
	case rw.underlying == nil:
		return nil, fmt.Errorf("%s: missing underlying db: %w", op, ErrInternal)
	case sql == "":
		return nil, fmt.Errorf("%s: missing sql: %w", op, ErrInvalidParameter)
	}
	opts := GetOpts(opt...)
	tx := rw.underlying.wrapped.WithContext(ctx)
	if opts.WithDebug {
		tx = tx.Debug()
	}
	if tx = tx.Raw(sql, values...); tx.Error != nil {
		return nil, fmt.Errorf("%s: %w", op, tx.Error)
	}
	return tx.Rows()
}
// ScanRows will scan the rows into the interface. The result must be a
// non-nil pointer the rows can be decoded into.
func (rw *RW) ScanRows(rows *sql.Rows, result interface{}) error {
	const op = "dbw.ScanRows"
	switch {
	case rw.underlying == nil:
		return fmt.Errorf("%s: missing underlying db: %w", op, ErrInternal)
	case rows == nil:
		return fmt.Errorf("%s: missing rows: %w", op, ErrInvalidParameter)
	case isNil(result):
		return fmt.Errorf("%s: missing result: %w", op, ErrInvalidParameter)
	}
	return rw.underlying.wrapped.ScanRows(rows, result)
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbw
import (
"context"
"fmt"
"reflect"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
)
const (
	// noRowsAffected is the rows-affected value returned alongside errors
	noRowsAffected = 0

	// DefaultLimit is the default for search results when no limit is specified
	// via the WithLimit(...) option
	DefaultLimit = 10000
)

// RW uses a DB as a connection for it's read/write operations. This is
// basically the primary type for the package's operations.
type RW struct {
	// underlying is the shared DB (connection pool) used for all operations
	underlying *DB
}

// ensure that RW implements the interfaces of: Reader and Writer
var (
	_ Reader = (*RW)(nil)
	_ Writer = (*RW)(nil)
)

// New creates a new RW using an open DB. Note: there can by many RWs that share
// the same DB, since the DB manages the connection pool.
func New(underlying *DB) *RW {
	return &RW{underlying: underlying}
}

// DB returns the underlying DB
func (rw *RW) DB() *DB {
	return rw.underlying
}
// Exec will execute the sql with the values as parameters. The int returned
// is the number of rows affected by the sql. The WithDebug option is supported.
func (rw *RW) Exec(ctx context.Context, sql string, values []interface{}, opt ...Option) (int, error) {
	const op = "dbw.Exec"
	if rw.underlying == nil {
		// use the noRowsAffected const here; the original returned a bare 0,
		// inconsistent with every other return in this function
		return noRowsAffected, fmt.Errorf("%s: missing underlying db: %w", op, ErrInternal)
	}
	if sql == "" {
		return noRowsAffected, fmt.Errorf("%s: missing sql: %w", op, ErrInvalidParameter)
	}
	opts := GetOpts(opt...)
	db := rw.underlying.wrapped.WithContext(ctx)
	if opts.WithDebug {
		db = db.Debug()
	}
	db = db.Exec(sql, values...)
	if db.Error != nil {
		return noRowsAffected, fmt.Errorf("%s: %w", op, db.Error)
	}
	return int(db.RowsAffected), nil
}
// primaryFieldsAreZero returns the names of i's primary key fields that are
// zero-valued, plus a bool reporting whether any were found.
func (rw *RW) primaryFieldsAreZero(ctx context.Context, i interface{}) ([]string, bool, error) {
	const op = "dbw.primaryFieldsAreZero"
	var fieldNames []string
	tx := rw.underlying.wrapped.Model(i)
	if err := tx.Statement.Parse(i); err != nil {
		// keep ErrInvalidParameter as the matchable sentinel, but don't
		// discard the parse error's detail (the original dropped err)
		return nil, false, fmt.Errorf("%s: %v: %w", op, err, ErrInvalidParameter)
	}
	for _, f := range tx.Statement.Schema.PrimaryFields {
		if f.PrimaryKey {
			if _, isZero := f.ValueOf(ctx, reflect.ValueOf(i)); isZero {
				fieldNames = append(fieldNames, f.Name)
			}
		}
	}
	return fieldNames, len(fieldNames) > 0, nil
}
// isNil reports whether i is nil, or is a non-nil interface wrapping a typed
// nil pointer, map, chan, slice, or func. This matters because a typed nil
// stored in an interface{} compares non-equal to nil.
func isNil(i interface{}) bool {
	if i == nil {
		return true
	}
	switch reflect.TypeOf(i).Kind() {
	// reflect.Func added: a nil func value previously reported non-nil,
	// even though IsNil is valid for funcs
	case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Slice, reflect.Func:
		return reflect.ValueOf(i).IsNil()
	}
	return false
}
// contains reports whether t matches any element of ss, comparing
// case-insensitively.
func contains(ss []string, t string) bool {
	for i := range ss {
		if strings.EqualFold(ss[i], t) {
			return true
		}
	}
	return false
}
// validateResourcesInterface verifies that resources is a pointer and, when
// it points at a slice, that the slice elements are themselves pointers.
func validateResourcesInterface(resources interface{}) error {
	const op = "dbw.validateResourcesInterface"
	vo := reflect.ValueOf(resources)
	if vo.Kind() != reflect.Ptr {
		return fmt.Errorf("%s: interface parameter must to be a pointer: %w", op, ErrInvalidParameter)
	}
	// vo.Elem() is safe here: the non-pointer case returned above
	if elem := vo.Elem(); elem.Kind() == reflect.Slice && elem.Type().Elem().Kind() != reflect.Ptr {
		return fmt.Errorf("%s: interface parameter is a slice, but the elements of the slice are not pointers: %w", op, ErrInvalidParameter)
	}
	return nil
}
// raiseErrorOnHooks returns an error if i (or, for a slice, its first
// element) implements any gorm callback/hook interface, since dbw does not
// support gorm hooks.
func raiseErrorOnHooks(i interface{}) error {
	const op = "dbw.raiseErrorOnHooks"
	v := i
	valOf := reflect.ValueOf(i)
	if valOf.Kind() == reflect.Slice {
		if valOf.Len() == 0 {
			return nil
		}
		// only the first element is inspected; the slice is assumed homogeneous
		v = valOf.Index(0).Interface()
	}
	switch v.(type) {
	case
		// create hooks
		callbacks.BeforeCreateInterface,
		callbacks.AfterCreateInterface,
		callbacks.BeforeSaveInterface,
		callbacks.AfterSaveInterface,
		// update hooks
		callbacks.BeforeUpdateInterface,
		callbacks.AfterUpdateInterface,
		// delete hooks
		callbacks.BeforeDeleteInterface,
		callbacks.AfterDeleteInterface,
		// find hooks
		callbacks.AfterFindInterface:
		return fmt.Errorf("%s: gorm callback/hooks are not supported: %w", op, ErrInvalidParameter)
	}
	return nil
}
// IsTx returns true if there's an existing transaction in progress
func (rw *RW) IsTx() bool {
	if rw.underlying == nil {
		return false
	}
	// a ConnPool that can still begin a transaction (TxBeginner /
	// ConnPoolBeginner) is the base connection pool, so no tx is in
	// progress; any other ConnPool type is treated as an in-flight tx
	switch rw.underlying.wrapped.Statement.ConnPool.(type) {
	case gorm.TxBeginner, gorm.ConnPoolBeginner:
		return false
	default:
		return true
	}
}
// whereClausesFromOpts builds a where clause (and its args) from the
// WithVersion and WithWhereClause options; individual clauses are joined
// with " and ".
func (rw *RW) whereClausesFromOpts(_ context.Context, i interface{}, opts Options) (string, []interface{}, error) {
	const op = "dbw.whereClausesFromOpts"
	var where []string
	var args []interface{}
	if opts.WithVersion != nil {
		if *opts.WithVersion == 0 {
			return "", nil, fmt.Errorf("%s: with version option is zero: %w", op, ErrInvalidParameter)
		}
		mDb := rw.underlying.wrapped.Model(i)
		err := mDb.Statement.Parse(i)
		// NOTE(review): a parse error with a non-nil Schema is silently
		// ignored here — confirm that's intended
		if err != nil && mDb.Statement.Schema == nil {
			return "", nil, fmt.Errorf("%s: (internal error) unable to parse stmt: %w", op, ErrUnknown)
		}
		// optimistic locking requires the resource to have a version column
		if !contains(mDb.Statement.Schema.DBNames, "version") {
			return "", nil, fmt.Errorf("%s: %s does not have a version field: %w", op, mDb.Statement.Schema.Table, ErrInvalidParameter)
		}
		if opts.WithOnConflict != nil {
			// on conflict clauses requires the version to be qualified with a
			// table name
			var tableName string
			switch {
			case opts.WithTable != "":
				tableName = opts.WithTable
			default:
				tableName = mDb.Statement.Schema.Table
			}
			where = append(where, fmt.Sprintf("%s.version = ?", tableName)) // we need to include the table name because of "on conflict" use cases
		} else {
			where = append(where, "version = ?")
		}
		args = append(args, opts.WithVersion)
	}
	if opts.WithWhereClause != "" {
		where, args = append(where, opts.WithWhereClause), append(args, opts.WithWhereClauseArgs...)
	}
	return strings.Join(where, " and "), args, nil
}
// clearDefaultNullResourceFields will clear fields in the resource which are
// defaulted to a null value. This addresses the unfixed issue in gorm:
// https://github.com/go-gorm/gorm/issues/6351
func (rw *RW) clearDefaultNullResourceFields(ctx context.Context, i interface{}) error {
	const op = "dbw.ClearResourceFields"
	stmt := rw.underlying.wrapped.Model(i).Statement
	if err := stmt.Parse(i); err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	v := reflect.ValueOf(i)
	for _, f := range stmt.Schema.Fields {
		switch {
		case f.PrimaryKey:
			// seems a bit redundant, with the test for null, but it's very
			// important to not clear the primary fields, so we'll make an
			// explicit test
			continue
		case !f.Updatable:
			// well, based on the gorm tags it's a read-only field, so we're done.
			continue
		case !strings.EqualFold(f.DefaultValue, "null"):
			// only fields whose gorm default is literally "null" are candidates
			continue
		default:
			_, isZero := f.ValueOf(ctx, v)
			if isZero {
				// already zero; nothing to clear
				continue
			}
			// reset the non-zero field to its default value interface
			if err := f.Set(stmt.Context, v, f.DefaultValueInterface); err != nil {
				return fmt.Errorf("%s: unable to set value of non-zero field: %w", op, err)
			}
		}
	}
	return nil
}
// primaryKeysWhere builds a "col = ? and ..." where clause plus args from the
// resource's primary key(s). ResourcePublicIder/ResourcePrivateIder take
// precedence over the schema-declared primary fields.
func (rw *RW) primaryKeysWhere(ctx context.Context, i interface{}) (string, []interface{}, error) {
	const op = "dbw.primaryKeysWhere"
	var fieldNames []string
	var fieldValues []interface{}
	tx := rw.underlying.wrapped.Model(i)
	if err := tx.Statement.Parse(i); err != nil {
		return "", nil, fmt.Errorf("%s: %w", op, err)
	}
	switch resourceType := i.(type) {
	case ResourcePublicIder:
		if resourceType.GetPublicId() == "" {
			return "", nil, fmt.Errorf("%s: missing primary key: %w", op, ErrInvalidParameter)
		}
		fieldValues = []interface{}{resourceType.GetPublicId()}
		fieldNames = []string{"public_id"}
	case ResourcePrivateIder:
		if resourceType.GetPrivateId() == "" {
			return "", nil, fmt.Errorf("%s: missing primary key: %w", op, ErrInvalidParameter)
		}
		fieldValues = []interface{}{resourceType.GetPrivateId()}
		fieldNames = []string{"private_id"}
	default:
		v := reflect.ValueOf(i)
		for _, f := range tx.Statement.Schema.PrimaryFields {
			if f.PrimaryKey {
				val, isZero := f.ValueOf(ctx, v)
				if isZero {
					return "", nil, fmt.Errorf("%s: primary field %s is zero: %w", op, f.Name, ErrInvalidParameter)
				}
				fieldNames = append(fieldNames, f.DBName)
				fieldValues = append(fieldValues, val)
			}
		}
	}
	if len(fieldNames) == 0 {
		// %T (type of i), not the bool verb %t the original used, which
		// rendered as "%!t(...)" noise in the error message
		return "", nil, fmt.Errorf("%s: no primary key(s) for %T: %w", op, i, ErrInvalidParameter)
	}
	clauses := make([]string, 0, len(fieldNames))
	for _, col := range fieldNames {
		clauses = append(clauses, fmt.Sprintf("%s = ?", col))
	}
	return strings.Join(clauses, " and "), fieldValues, nil
}
// LookupWhere will lookup the first resource using a where clause with
// parameters (it only returns the first one). Supports WithDebug, and
// WithTable options.
func (rw *RW) LookupWhere(ctx context.Context, resource interface{}, where string, args []interface{}, opt ...Option) error {
	const op = "dbw.LookupWhere"
	if rw.underlying == nil {
		return fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter)
	}
	if err := validateResourcesInterface(resource); err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	if err := raiseErrorOnHooks(resource); err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	opts := GetOpts(opt...)
	tx := rw.underlying.wrapped.WithContext(ctx)
	if opts.WithTable != "" {
		tx = tx.Table(opts.WithTable)
	}
	if opts.WithDebug {
		tx = tx.Debug()
	}
	err := tx.Where(where, args...).First(resource).Error
	switch {
	case err == nil:
		return nil
	case err == gorm.ErrRecordNotFound:
		// normalize gorm's sentinel into the package's own
		return fmt.Errorf("%s: %w", op, ErrRecordNotFound)
	default:
		return fmt.Errorf("%s: %w", op, err)
	}
}
// SearchWhere will search for all the resources it can find using a where
// clause with parameters. An error will be returned if args are provided without a
// where clause.
//
// Supports the WithLimit, WithOrder, WithTable, and WithDebug options. If
// WithLimit < 0, then unlimited results are returned. If WithLimit == 0, then
// default limits are used for results (see DefaultLimit const).
func (rw *RW) SearchWhere(ctx context.Context, resources interface{}, where string, args []interface{}, opt ...Option) error {
	const op = "dbw.SearchWhere"
	opts := GetOpts(opt...)
	if rw.underlying == nil {
		return fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter)
	}
	if where == "" && len(args) > 0 {
		return fmt.Errorf("%s: args provided with empty where: %w", op, ErrInvalidParameter)
	}
	if err := raiseErrorOnHooks(resources); err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	if err := validateResourcesInterface(resources); err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	var err error
	db := rw.underlying.wrapped.WithContext(ctx)
	if opts.WithOrder != "" {
		db = db.Order(opts.WithOrder)
	}
	if opts.WithDebug {
		db = db.Debug()
	}
	if opts.WithTable != "" {
		db = db.Table(opts.WithTable)
	}
	// Perform limiting
	switch {
	case opts.WithLimit < 0: // any negative number signals unlimited results
	case opts.WithLimit == 0: // zero signals the default value and default limits
		db = db.Limit(DefaultLimit)
	default:
		db = db.Limit(opts.WithLimit)
	}
	if where != "" {
		db = db.Where(where, args...)
	}
	// Perform the query
	err = db.Find(resources).Error
	if err != nil {
		// searching with a slice parameter does not return a gorm.ErrRecordNotFound
		return fmt.Errorf("%s: %w", op, err)
	}
	return nil
}
// Dialect returns the RW's underlying database type and its raw driver name.
func (rw *RW) Dialect() (_ DbType, rawName string, _ error) {
	return rw.underlying.DbType()
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbw
import (
"context"
"database/sql"
"fmt"
"os"
"strings"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/xo/dburl"
pgDriver "gorm.io/driver/postgres"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm/logger"
)
// TestSetup is typically called before starting a test and will setup the
// database for the test (initialize the database one-time). Do not close the
// returned db. Supported test options: WithDebug, WithTestDialect,
// WithTestDatabaseUrl, WithTestMigration and WithTestMigrationUsingDB.
func TestSetup(t *testing.T, opt ...TestOption) (*DB, string) {
	require := require.New(t)
	var url string
	var err error
	InitNonUpdatableFields([]string{"CreateTime", "UpdateTime", "PublicId"})
	InitNonCreatableFields([]string{"CreateTime", "UpdateTime"})
	ctx := context.Background()
	opts := getTestOpts(opt...)
	// the DB_DIALECT env var overrides any dialect option; sqlite is the
	// fallback default
	switch strings.ToLower(os.Getenv("DB_DIALECT")) {
	case "postgres":
		opts.withDialect = Postgres.String()
	case "sqlite":
		opts.withDialect = Sqlite.String()
	default:
		if opts.withDialect == "" {
			opts.withDialect = Sqlite.String()
		}
	}
	// the DB_DSN env var overrides any database url option
	if url := os.Getenv("DB_DSN"); url != "" {
		opts.withTestDatabaseUrl = url
	}
	switch {
	case opts.withDialect == Postgres.String() && opts.withTestDatabaseUrl == "":
		t.Fatal("missing postgres test db url")
	case opts.withDialect == Sqlite.String() && opts.withTestDatabaseUrl == "":
		url = "file::memory:" // just using a temp in-memory sqlite database
	default:
		url = opts.withTestDatabaseUrl
	}
	switch opts.withDialect {
	case Postgres.String():
		// for postgres: create a throwaway database for this test, point the
		// url at it, and drop it again in t.Cleanup
		u, err := dburl.Parse(opts.withTestDatabaseUrl)
		require.NoError(err)
		db, err := Open(Postgres, u.DSN)
		require.NoError(err)
		rw := New(db)
		tmpDbName, err := NewId("go_db_tmp")
		// NOTE(review): tmpDbName is used on the next line before err is
		// checked — consider checking err immediately after NewId
		tmpDbName = strings.ToLower(tmpDbName)
		require.NoError(err)
		_, err = rw.Exec(ctx, fmt.Sprintf(`create database "%s"`, tmpDbName), nil)
		require.NoError(err)
		t.Cleanup(func() {
			// terminate lingering connections before dropping the tmp db
			_, err = rw.Exec(ctx, `select pg_terminate_backend(pid) from pg_stat_activity where datname = ? and pid <> pg_backend_pid()`, []interface{}{tmpDbName})
			assert.NoError(t, err)
			// NOTE(review): create quotes the db name but drop does not —
			// confirm the quoting inconsistency is harmless (NewId output is
			// lowercased base62, so it likely is)
			_, err = rw.Exec(ctx, fmt.Sprintf(`drop database %s`, tmpDbName), nil)
			assert.NoError(t, err)
		})
		_, err = rw.Exec(ctx, fmt.Sprintf(`grant all privileges on database %s to %s`, tmpDbName, u.User.Username()), nil)
		require.NoError(err)
		// swap the db-name segment of the url's path for the tmp db name
		namesSegs := strings.Split(strings.TrimPrefix(u.Path, "/"), "?")
		require.Truef(len(namesSegs) > 0, "couldn't determine db name from URL")
		namesSegs[0] = tmpDbName
		u.Path = strings.Join(namesSegs, "?")
		url, _, err = dburl.GenPostgres(u)
		require.NoError(err)
	}
	dbType, err := StringToDbType(opts.withDialect)
	require.NoError(err)
	db, err := Open(dbType, url)
	require.NoError(err)
	// NOTE(review): LogMode returns a new logger and the return value is
	// discarded here, so the level change may not take effect — confirm
	// whether db.wrapped.Logger should be reassigned
	db.wrapped.Logger.LogMode(logger.Error)
	t.Cleanup(func() {
		assert.NoError(t, db.Close(ctx), "Got error closing db.")
	})
	if opts.withTestDebug || strings.ToLower(os.Getenv("DEBUG")) == "true" {
		db.Debug(true)
	}
	// we're only going to run one set of migrations. Either one of the
	// migration functions passed in as an option or the default
	// TestCreateTables(...)
	switch {
	case opts.withTestMigration != nil:
		err = opts.withTestMigration(ctx, opts.withDialect, url)
		require.NoError(err)
	case opts.withTestMigrationUsingDb != nil:
		var rawDB *sql.DB
		switch opts.withDialect {
		case Sqlite.String():
			// we need to special case handle sqlite because we may want to run
			// the migration on an in-memory database that isn't shared
			// "file::memory:" or ":memory:" so we have to be sure that we don't
			// open a new connection, which would create a new in-memory db vs
			// using the existing one already opened in this function a few
			// lines above with: db, err := Open(dbType, url)
			// see: https://www.sqlite.org/inmemorydb.html
			//
			// luckily, the gorm sqlite ConnPool is the existing opened *sql.DB,
			// so we can just run the migration on that conn
			var ok bool
			rawDB, ok = db.wrapped.ConnPool.(*sql.DB)
			require.True(ok, "expected the gorm ConnPool to be an *sql.DB")
		default:
			var err error
			rawDB, err = db.wrapped.DB()
			if err != nil {
				require.NoError(err)
			}
		}
		err = opts.withTestMigrationUsingDb(ctx, rawDB)
		require.NoError(err)
	default:
		TestCreateTables(t, db)
	}
	return db, url
}
// TestSetupWithMock will return a test DB and an associated Sqlmock which can
// be used to mock out the db responses.
func TestSetupWithMock(t *testing.T) (*DB, sqlmock.Sqlmock) {
	t.Helper()
	require := require.New(t)
	// stand up the sqlmock connection...
	db, mock, err := sqlmock.New()
	require.NoError(err)
	// ...and wrap it with the postgres gorm driver so callers get a *DB whose
	// queries are served by the mock
	dbw, err := OpenWith(pgDriver.New(pgDriver.Config{
		Conn: db,
	}))
	require.NoError(err)
	return dbw, mock
}
// getTestOpts applies each inbound TestOption to a default testOptions and
// returns the result.
func getTestOpts(opt ...TestOption) testOptions {
	opts := getDefaultTestOptions()
	for _, apply := range opt {
		apply(&opts)
	}
	return opts
}

// TestOption - how Options are passed as arguments
type TestOption func(*testOptions)

// testOptions is how test options are represented internally
type testOptions struct {
	withDialect              string
	withTestDatabaseUrl      string
	withTestMigration        func(ctx context.Context, dialect, url string) error
	withTestMigrationUsingDb func(ctx context.Context, db *sql.DB) error
	withTestDebug            bool
}

// getDefaultTestOptions returns the zero-value defaults for test options.
func getDefaultTestOptions() testOptions {
	var opts testOptions
	return opts
}

// WithTestDialect provides a way to specify the test database dialect
func WithTestDialect(dialect string) TestOption {
	return func(to *testOptions) {
		to.withDialect = dialect
	}
}

// WithTestMigration provides a way to specify an option func which runs a
// required database migration to initialize the database
func WithTestMigration(migrationFn func(ctx context.Context, dialect, url string) error) TestOption {
	return func(to *testOptions) {
		to.withTestMigration = migrationFn
	}
}

// WithTestMigrationUsingDB provides a way to specify an option func which runs a
// required database migration to initialize the database using an existing open
// sql.DB
func WithTestMigrationUsingDB(migrationFn func(ctx context.Context, db *sql.DB) error) TestOption {
	return func(to *testOptions) {
		to.withTestMigrationUsingDb = migrationFn
	}
}

// WithTestDatabaseUrl provides a way to specify an existing database for tests
func WithTestDatabaseUrl(url string) TestOption {
	return func(to *testOptions) {
		to.withTestDatabaseUrl = url
	}
}
// TestCreateTables will create the test tables for the dbw pkg
func TestCreateTables(t *testing.T, conn *DB) {
	t.Helper()
	require := require.New(t)
	testCtx := context.Background()
	rw := New(conn)
	dialect := conn.wrapped.Dialector.Name()
	// pick the DDL script that matches the connection's dialect
	query, ok := map[string]string{
		"sqlite":   testQueryCreateTablesSqlite,
		"postgres": testQueryCreateTablesPostgres,
	}[dialect]
	if !ok {
		t.Fatalf("unknown dialect: %s", dialect)
	}
	_, err := rw.Exec(testCtx, query, nil)
	require.NoError(err)
}
// testDropTables drops the dbw pkg test tables for the connection's dialect.
func testDropTables(t *testing.T, conn *DB) {
	t.Helper()
	require := require.New(t)
	testCtx := context.Background()
	rw := New(conn)
	dialect := conn.wrapped.Dialector.Name()
	// pick the drop script that matches the connection's dialect
	query, ok := map[string]string{
		"sqlite":   testQueryDropTablesSqlite,
		"postgres": testQueryDropTablesPostgres,
	}[dialect]
	if !ok {
		t.Fatalf("unknown dialect: %s", dialect)
	}
	_, err := rw.Exec(testCtx, query, nil)
	require.NoError(err)
}
// SQL fixtures for the dbw test schema. Each dialect has its own create/drop
// script; each script is executed as a single Exec(...) call by
// TestCreateTables / testDropTables. The raw-string contents are runtime
// behavior and must not be edited casually.
const (
// testQueryCreateTablesSqlite creates the sqlite test tables (db_test_user,
// db_test_car, db_test_rental, db_test_scooter) plus triggers that emulate
// the postgres-side behaviors: update_time maintenance, create_time
// defaulting, immutable create_time, and version bumping on update.
testQueryCreateTablesSqlite = `
begin;
create table if not exists db_test_user (
public_id text not null constraint db_test_user_pkey primary key,
create_time timestamp not null default current_timestamp,
update_time timestamp not null default current_timestamp,
name text unique,
phone_number text,
email text,
version int default 1
);
create trigger update_time_column_db_test_user
before update on db_test_user
for each row
when
new.public_id <> old.public_id or
new.name <> old.name or
new.phone_number <> old.phone_number or
new.email <> old.email or
new.version <> old.version
begin
update db_test_user set update_time = datetime('now','localtime') where rowid == new.rowid;
end;
create trigger immutable_columns_db_test_user
before update on db_test_user
for each row
when
new.create_time <> old.create_time
begin
select raise(abort, 'immutable column');
end;
create trigger default_create_time_column_db_test_user
before insert on db_test_user
for each row
begin
update db_test_user set create_time = datetime('now','localtime') where rowid = new.rowid;
end;
create trigger update_version_column_db_test_user
after update on db_test_user
for each row
when
new.public_id <> old.public_id or
new.name <> old.name or
new.phone_number <> old.phone_number or
new.email <> old.email
begin
update db_test_user set version = old.version + 1 where rowid = new.rowid;
end;
create table if not exists db_test_car (
public_id text constraint db_test_car_pkey primary key,
create_time timestamp not null default current_timestamp,
update_time timestamp not null default current_timestamp,
name text unique,
model text,
mpg smallint,
version int default 1
);
create trigger update_time_column_db_test_car
before update on db_test_car
for each row
when
new.public_id <> old.public_id or
new.name <> old.name or
new.model <> old.model or
new.mpg <> old.mpg or
new.version <> old.version
begin
update db_test_car set update_time = datetime('now','localtime') where rowid == new.rowid;
end;
create trigger default_create_time_column_db_test_car
before insert on db_test_car
for each row
begin
update db_test_car set create_time = datetime('now','localtime') where rowid = new.rowid;
end;
create trigger update_version_column_db_test_car
after update on db_test_car
for each row
when
new.public_id <> old.public_id or
new.name <> old.name or
new.model <> old.model or
new.mpg <> old.mpg
begin
update db_test_car set version = old.version + 1 where rowid = new.rowid;
end;
create table if not exists db_test_rental (
user_id text not null references db_test_user(public_id),
car_id text not null references db_test_car(public_id),
create_time timestamp not null default current_timestamp,
update_time timestamp not null default current_timestamp,
name text unique,
version int default 1,
constraint db_test_rental_pkey primary key(user_id, car_id)
);
create trigger update_time_column_db_test_rental
before update on db_test_rental
for each row
when
new.user_id <> old.user_id or
new.car_id <> old.car_id or
new.name <> old.name or
new.version <> old.version
begin
update db_test_rental set update_time = datetime('now','localtime') where rowid == new.rowid;
end;
create trigger immutable_columns_db_test_rental
before update on db_test_rental
for each row
when
new.create_time <> old.create_time
begin
select raise(abort, 'immutable column');
end;
create trigger default_create_time_column_db_test_rental
before insert on db_test_rental
for each row
begin
update db_test_rental set create_time = datetime('now','localtime') where rowid = new.rowid;
end;
create trigger update_version_column_db_test_rental
after update on db_test_rental
for each row
when
new.user_id <> old.user_id or
new.car_id <> old.car_id or
new.name <> old.name or
new.version <> old.version
begin
update db_test_rental set version = old.version + 1 where rowid = new.rowid;
end;
create table if not exists db_test_scooter (
private_id text constraint db_test_scooter_pkey primary key,
create_time timestamp not null default current_timestamp,
update_time timestamp not null default current_timestamp,
name text unique,
model text,
mpg smallint,
version int default 1
);
create trigger update_time_column_db_test_scooter
before update on db_test_scooter
for each row
when
new.private_id <> old.private_id or
new.name <> old.name or
new.model <> old.model or
new.mpg <> old.mpg or
new.version <> old.version
begin
update db_test_scooter set update_time = datetime('now','localtime') where rowid == new.rowid;
end;
create trigger default_create_time_column_db_test_scooter
before insert on db_test_scooter
for each row
begin
update db_test_scooter set create_time = datetime('now','localtime') where rowid = new.rowid;
end;
create trigger update_version_column_db_test_scooter
after update on db_test_scooter
for each row
when
new.private_id <> old.private_id or
new.name <> old.name or
new.model <> old.model or
new.mpg <> old.mpg
begin
update db_test_scooter set version = old.version + 1 where rowid = new.rowid;
end;
commit;
`
// testQueryCreateTablesPostgres creates the postgres test schema: the wt_*
// domains, the shared trigger functions (update_time_column, default_create_time,
// update_version_column, immutable_columns), the test tables, and their triggers.
//
// NOTE(review): unlike the sqlite script above, the postgres db_test_car table
// has no version column and no update_version trigger — confirm this asymmetry
// between the dialects is intentional.
testQueryCreateTablesPostgres = `
begin;
create domain wt_public_id as text
check(
length(trim(value)) > 10
);
comment on domain wt_public_id is
'Random ID generated with github.com/hashicorp/go-secure-stdlib/base62';
create domain wt_private_id as text
not null
check(
length(trim(value)) > 10
);
comment on domain wt_private_id is
'Random ID generated with github.com/hashicorp/go-secure-stdlib/base62';
drop domain if exists wt_timestamp;
create domain wt_timestamp as
timestamp with time zone
default current_timestamp;
comment on domain wt_timestamp is
'Standard timestamp for all create_time and update_time columns';
create or replace function
update_time_column()
returns trigger
as $$
begin
if row(new.*) is distinct from row(old.*) then
new.update_time = now();
return new;
else
return old;
end if;
end;
$$ language plpgsql;
comment on function
update_time_column()
is
'function used in before update triggers to properly set update_time columns';
create or replace function
default_create_time()
returns trigger
as $$
begin
if new.create_time is distinct from now() then
raise warning 'create_time cannot be set to %', new.create_time;
new.create_time = now();
end if;
return new;
end;
$$ language plpgsql;
comment on function
default_create_time()
is
'function used in before insert triggers to set create_time column to now';
create domain wt_version as bigint
default 1
not null
check(
value > 0
);
comment on domain wt_version is
'standard column for row version';
-- update_version_column() will increment the version column whenever row data
-- is updated and should only be used in an update after trigger. This function
-- will overwrite any explicit updates to the version column. The function
-- accepts an optional parameter of 'private_id' for the tables primary key.
create or replace function
update_version_column()
returns trigger
as $$
begin
if pg_trigger_depth() = 1 then
if row(new.*) is distinct from row(old.*) then
if tg_nargs = 0 then
execute format('update %I set version = $1 where public_id = $2', tg_relid::regclass) using old.version+1, new.public_id;
new.version = old.version + 1;
return new;
end if;
if tg_argv[0] = 'private_id' then
execute format('update %I set version = $1 where private_id = $2', tg_relid::regclass) using old.version+1, new.private_id;
new.version = old.version + 1;
return new;
end if;
end if;
end if;
return new;
end;
$$ language plpgsql;
comment on function
update_version_column()
is
'function used in after update triggers to properly set version columns';
-- immutable_columns() will make the column names immutable which are passed as
-- parameters when the trigger is created. It raises error code 23601 which is a
-- class 23 integrity constraint violation: immutable column
create or replace function
immutable_columns()
returns trigger
as $$
declare
col_name text;
new_value text;
old_value text;
begin
foreach col_name in array tg_argv loop
execute format('SELECT $1.%I', col_name) into new_value using new;
execute format('SELECT $1.%I', col_name) into old_value using old;
if new_value is distinct from old_value then
raise exception 'immutable column: %.%', tg_table_name, col_name using
errcode = '23601',
schema = tg_table_schema,
table = tg_table_name,
column = col_name;
end if;
end loop;
return new;
end;
$$ language plpgsql;
comment on function
immutable_columns()
is
'function used in before update triggers to make columns immutable';
-- ########################################################################################
create table if not exists db_test_user (
public_id wt_public_id constraint db_test_user_pkey primary key,
create_time wt_timestamp,
update_time wt_timestamp,
name text unique,
phone_number text,
email text,
version wt_version
);
create trigger update_time_column
before
update on db_test_user
for each row execute procedure update_time_column();
-- define the immutable fields for db_test_user
create trigger immutable_columns
before
update on db_test_user
for each row execute procedure immutable_columns('create_time');
create trigger default_create_time_column
before
insert on db_test_user
for each row execute procedure default_create_time();
create trigger update_version_column
after update on db_test_user
for each row execute procedure update_version_column();
create table if not exists db_test_car (
public_id wt_public_id constraint db_test_car_pkey primary key,
create_time wt_timestamp,
update_time wt_timestamp,
name text unique,
model text,
mpg smallint
);
create trigger update_time_column
before
update on db_test_car
for each row execute procedure update_time_column();
-- define the immutable fields for db_test_car
create trigger immutable_columns
before
update on db_test_car
for each row execute procedure immutable_columns('create_time');
create trigger default_create_time_column
before
insert on db_test_car
for each row execute procedure default_create_time();
create table if not exists db_test_rental (
user_id wt_public_id not null references db_test_user(public_id),
car_id wt_public_id not null references db_test_car(public_id),
create_time wt_timestamp,
update_time wt_timestamp,
name text unique,
version wt_version,
constraint db_test_rental_pkey primary key(user_id, car_id)
);
create trigger update_time_column
before
update on db_test_rental
for each row execute procedure update_time_column();
-- define the immutable fields for db_test_rental
create trigger immutable_columns
before
update on db_test_rental
for each row execute procedure immutable_columns('create_time');
create trigger default_create_time_column
before
insert on db_test_rental
for each row execute procedure default_create_time();
-- update_version_column() will increment the version column whenever row data
-- is updated and should only be used in an update after trigger. This function
-- will overwrite any explicit updates to the version column. The function
-- accepts an optional parameter of 'private_id' for the tables primary key.
create or replace function
update_rental_version_column()
returns trigger
as $$
begin
if pg_trigger_depth() = 1 then
if row(new.*) is distinct from row(old.*) then
if tg_nargs = 0 then
execute format('update %I set version = $1 where user_id = $2 and car_id = $3', tg_relid::regclass) using old.version+1, new.user_id, new.car_id;
new.version = old.version + 1;
return new;
end if;
end if;
end if;
return new;
end;
$$ language plpgsql;
create trigger update_version_column
after update on db_test_rental
for each row execute procedure update_rental_version_column();
create table if not exists db_test_scooter (
private_id wt_private_id constraint db_test_scooter_pkey primary key,
create_time wt_timestamp,
update_time wt_timestamp,
name text unique,
model text,
mpg smallint
);
create trigger update_time_column
before update on db_test_scooter
for each row execute procedure update_time_column();
-- define the immutable fields for db_test_scooter
create trigger immutable_columns
before update on db_test_scooter
for each row execute procedure immutable_columns('create_time');
create trigger default_create_time_column
before insert on db_test_scooter
for each row execute procedure default_create_time();
commit;
`
// testQueryDropTablesSqlite drops all sqlite test tables (triggers are
// dropped implicitly with their tables).
testQueryDropTablesSqlite = `
begin;
drop table if exists db_test_user;
drop table if exists db_test_car;
drop table if exists db_test_rental;
drop table if exists db_test_scooter;
commit;
`
// testQueryDropTablesPostgres drops all postgres test tables (cascade removes
// dependent triggers/constraints) and the wt_* domains created above.
testQueryDropTablesPostgres = `
begin;
drop table if exists db_test_user cascade;
drop table if exists db_test_car cascade;
drop table if exists db_test_rental cascade;
drop table if exists db_test_scooter cascade;
drop domain if exists wt_public_id;
drop domain if exists wt_private_id;
drop domain if exists wt_timestamp;
drop domain if exists wt_version;
commit;
`
)
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbw
import (
"context"
"fmt"
)
// Begin will start a transaction
func (rw *RW) Begin(ctx context.Context) (*RW, error) {
	const op = "dbw.Begin"
	// start the transaction on a context-bound session and hand the caller a
	// new RW scoped to that transaction
	tx := rw.underlying.wrapped.WithContext(ctx).Begin()
	if tx.Error != nil {
		return nil, fmt.Errorf("%s: %w", op, tx.Error)
	}
	return New(&DB{wrapped: tx}), nil
}
// Rollback will rollback the current transaction
func (rw *RW) Rollback(ctx context.Context) error {
	const op = "dbw.Rollback"
	// roll back on a context-bound session so cancellation is honored
	if err := rw.underlying.wrapped.WithContext(ctx).Rollback().Error; err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	return nil
}
// Commit will commit a transaction
func (rw *RW) Commit(ctx context.Context) error {
	const op = "dbw.Commit"
	// commit on a context-bound session so cancellation is honored
	if err := rw.underlying.wrapped.WithContext(ctx).Commit().Error; err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	return nil
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbw
import (
"context"
"fmt"
"reflect"
"sync/atomic"
"gorm.io/gorm"
)
// nonUpdateFields holds the package-level set of field names that
// RW.Update(...) must never modify; it stores a map[string]struct{}.
var nonUpdateFields atomic.Value

// InitNonUpdatableFields sets the fields which are not updatable via
// RW.Update(...)
func InitNonUpdatableFields(fields []string) {
	set := make(map[string]struct{}, len(fields))
	for _, name := range fields {
		set[name] = struct{}{}
	}
	nonUpdateFields.Store(set)
}

// NonUpdatableFields returns the current set of fields which are not updatable
// via RW.Update(...)
func NonUpdatableFields() []string {
	stored := nonUpdateFields.Load()
	if stored == nil {
		// never initialized; report an empty (non-nil) set
		return []string{}
	}
	set := stored.(map[string]struct{})
	names := make([]string, 0, len(set))
	for name := range set {
		names = append(names, name)
	}
	return names
}
// Update a resource in the db, a fieldMask is required and provides
// field_mask.proto paths for fields that should be updated. The i interface
// parameter is the type the caller wants to update in the db and its fields are
// set to the update values. setToNullPaths is optional and provides
// field_mask.proto paths for the fields that should be set to null.
// fieldMaskPaths and setToNullPaths must not intersect. The caller is
// responsible for the transaction life cycle of the writer and if an error is
// returned the caller must decide what to do with the transaction, which almost
// always should be to rollback. Update returns the number of rows updated.
//
// Supported options: WithBeforeWrite, WithAfterWrite, WithWhere, WithDebug,
// WithTable and WithVersion. If WithVersion is used, then the update will
// include the version number in the update where clause, which basically makes
// the update use optimistic locking and the update will only succeed if the
// existing rows version matches the WithVersion option. Zero is not a valid
// value for the WithVersion option and will return an error. WithWhere allows
// specifying an additional constraint on the operation in addition to the PKs.
// WithDebug will turn on debugging for the update call.
func (rw *RW) Update(ctx context.Context, i interface{}, fieldMaskPaths []string, setToNullPaths []string, opt ...Option) (int, error) {
	const op = "dbw.Update"
	// validate the writer and the resource before doing any work
	if rw.underlying == nil {
		return noRowsAffected, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter)
	}
	if isNil(i) {
		return noRowsAffected, fmt.Errorf("%s: missing interface: %w", op, ErrInvalidParameter)
	}
	if err := raiseErrorOnHooks(i); err != nil {
		return noRowsAffected, fmt.Errorf("%s: %w", op, err)
	}
	// at least one of the two path lists must name a field to change
	if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 {
		return noRowsAffected, fmt.Errorf("%s: both fieldMaskPaths and setToNullPaths are missing: %w", op, ErrInvalidParameter)
	}
	opts := GetOpts(opt...)
	// we need to filter out some non-updatable fields (like: CreateTime, etc)
	fieldMaskPaths = filterPaths(fieldMaskPaths)
	setToNullPaths = filterPaths(setToNullPaths)
	// filtering may have removed everything the caller asked for
	if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 {
		return noRowsAffected, fmt.Errorf("%s: after filtering non-updated fields, there are no fields left in fieldMaskPaths or setToNullPaths: %w", op, ErrInvalidParameter)
	}
	// build the column => value map handed to gorm's Updates(...) below
	updateFields, err := UpdateFields(i, fieldMaskPaths, setToNullPaths)
	if err != nil {
		return noRowsAffected, fmt.Errorf("%s: getting update fields failed: %w", op, err)
	}
	if len(updateFields) == 0 {
		return noRowsAffected, fmt.Errorf("%s: no fields matched using fieldMaskPaths %s: %w", op, fieldMaskPaths, ErrInvalidParameter)
	}
	// reject updates whose primary key(s) are unset; the PKs select the row(s)
	names, isZero, err := rw.primaryFieldsAreZero(ctx, i)
	if err != nil {
		return noRowsAffected, fmt.Errorf("%s: %w", op, err)
	}
	if isZero {
		return noRowsAffected, fmt.Errorf("%s: primary key is not set for: %s: %w", op, names, ErrInvalidParameter)
	}
	// parse the gorm schema so we can inspect the model's primary fields
	mDb := rw.underlying.wrapped.Model(i)
	err = mDb.Statement.Parse(i)
	if err != nil || mDb.Statement.Schema == nil {
		return noRowsAffected, fmt.Errorf("%s: internal error: unable to parse stmt: %w", op, err)
	}
	reflectValue := reflect.Indirect(reflect.ValueOf(i))
	// re-check each primary field via the parsed schema (per-field error
	// messages) and additionally reject any PK that appears in the field mask
	for _, pf := range mDb.Statement.Schema.PrimaryFields {
		if _, isZero := pf.ValueOf(ctx, reflectValue); isZero {
			return noRowsAffected, fmt.Errorf("%s: primary key %s is not set: %w", op, pf.Name, ErrInvalidParameter)
		}
		if contains(fieldMaskPaths, pf.Name) {
			return noRowsAffected, fmt.Errorf("%s: not allowed on primary key field %s: %w", op, pf.Name, ErrInvalidFieldMask)
		}
	}
	// give the resource a chance to vet itself before the write, unless the
	// caller opted out
	if !opts.WithSkipVetForWrite {
		if vetter, ok := i.(VetForWriter); ok {
			if err := vetter.VetForWrite(ctx, rw, UpdateOp, WithFieldMaskPaths(fieldMaskPaths), WithNullPaths(setToNullPaths)); err != nil {
				return noRowsAffected, fmt.Errorf("%s: %w", op, err)
			}
		}
	}
	// caller-supplied callback before the write
	if opts.WithBeforeWrite != nil {
		if err := opts.WithBeforeWrite(i); err != nil {
			return noRowsAffected, fmt.Errorf("%s: error before write: %w", op, err)
		}
	}
	underlying := rw.underlying.wrapped.Model(i)
	if opts.WithDebug {
		underlying = underlying.Debug()
	}
	if opts.WithTable != "" {
		underlying = underlying.Table(opts.WithTable)
	}
	switch {
	case opts.WithVersion != nil || opts.WithWhereClause != "":
		// optimistic locking (WithVersion) and/or an extra constraint
		// (WithWhere) are folded into the where clause alongside the PKs
		where, args, err := rw.whereClausesFromOpts(ctx, i, opts)
		if err != nil {
			return noRowsAffected, fmt.Errorf("%s: %w", op, err)
		}
		underlying = underlying.Where(where, args...).Updates(updateFields)
	default:
		underlying = underlying.Updates(updateFields)
	}
	if underlying.Error != nil {
		if underlying.Error == gorm.ErrRecordNotFound {
			return noRowsAffected, fmt.Errorf("%s: %w", op, gorm.ErrRecordNotFound)
		}
		return noRowsAffected, fmt.Errorf("%s: %w", op, underlying.Error)
	}
	rowsUpdated := int(underlying.RowsAffected)
	// caller-supplied callback after the write, only if something changed
	if rowsUpdated > 0 && (opts.WithAfterWrite != nil) {
		if err := opts.WithAfterWrite(i, rowsUpdated); err != nil {
			return rowsUpdated, fmt.Errorf("%s: error after write: %w", op, err)
		}
	}
	// we need to force a lookupAfterWrite so the resource returned is correctly initialized
	// from the db
	// NOTE(review): if this lookup fails after a successful update, the rows
	// actually updated are reported as noRowsAffected — confirm callers are ok
	// with losing the count in that case.
	opt = append(opt, WithLookup(true))
	if err := rw.lookupAfterWrite(ctx, i, opt...); err != nil {
		return noRowsAffected, fmt.Errorf("%s: %w", op, err)
	}
	return rowsUpdated, nil
}
// filterPaths will filter out non-updatable fields, returning only the paths
// that are allowed to be updated. A nil slice is returned when nothing
// survives the filter.
func filterPaths(paths []string) []string {
	if len(paths) == 0 {
		return nil
	}
	nonUpdatable := NonUpdatableFields()
	if len(nonUpdatable) == 0 {
		// nothing to filter against; keep the caller's paths as-is
		return paths
	}
	// deliberately left nil so an all-filtered input yields a nil slice,
	// matching the zero-length checks done by callers
	var kept []string
	for _, path := range paths {
		if !contains(nonUpdatable, path) {
			kept = append(kept, path)
		}
	}
	return kept
}