package sqlx
import (
"database/sql/driver"
"errors"
"reflect"
"strconv"
"strings"
"sync"
"github.com/muir/sqltoken"
"github.com/vinovest/sqlx/reflectx"
)
// Bindvar types supported by Rebind, BindMap and BindStruct.
//
// Each value identifies the placeholder syntax a driver expects; see Rebind
// and compileNamedQuery for the exact text each one produces.
const (
// UNKNOWN is what BindType reports for unregistered drivers; queries keep
// `?` placeholders (Rebind returns them unchanged).
UNKNOWN = iota
// QUESTION uses `?` placeholders.
QUESTION
// DOLLAR uses numbered `$1`, `$2`, ... placeholders.
DOLLAR
// NAMED uses colon-prefixed placeholders (Rebind emits `:arg1`, `:arg2`, ...).
NAMED
// AT uses numbered `@p1`, `@p2`, ... placeholders.
AT
)
// defaultBinds lists, per bindvar type, the driver names that init registers
// via BindDriver. Drivers not listed here report UNKNOWN from BindType unless
// they are registered later with BindDriver.
var defaultBinds = map[int][]string{
DOLLAR: {"postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "ql", "nrpostgres", "cockroach"},
QUESTION: {"mysql", "sqlite3", "nrmysql", "nrsqlite3"},
NAMED: {"oci8", "ora", "goracle", "godror"},
AT: {"sqlserver", "azuresql"},
}
// binds maps a driver name (string key) to its bindvar type (int value). It is
// written by BindDriver and read by BindType; sync.Map makes those calls safe
// to use concurrently.
var binds sync.Map
// rebindConfigs holds, indexed by bindvar type, the sqltoken configuration used
// by Rebind (and by In, which borrows the DOLLAR entry) to locate `?`
// placeholders. Each dialect config turns on QuestionMark tokens and separate
// punctuation tokens, and turns off recognition of that dialect's own native
// placeholder syntax. The QUESTION and UNKNOWN slots stay zero-valued; Rebind
// returns before consulting them.
var rebindConfigs = func() []sqltoken.Config {
configs := make([]sqltoken.Config, AT+1)
// PostgreSQL: `$N` is not treated specially; `?` is.
pg := sqltoken.PostgreSQLConfig()
pg.NoticeQuestionMark = true
pg.NoticeDollarNumber = false
pg.SeparatePunctuation = true
configs[DOLLAR] = pg
// Oracle: `:word` is not treated specially; `?` is.
ora := sqltoken.OracleConfig()
ora.NoticeColonWord = false
ora.NoticeQuestionMark = true
ora.SeparatePunctuation = true
configs[NAMED] = ora
// SQL Server: `@word` is not treated specially; `?` is.
ssvr := sqltoken.SQLServerConfig()
ssvr.NoticeAtWord = false
ssvr.NoticeQuestionMark = true
ssvr.SeparatePunctuation = true
configs[AT] = ssvr
return configs
}()
// init registers every driver listed in defaultBinds under its bindvar type.
func init() {
	for bindType, driverNames := range defaultBinds {
		for idx := range driverNames {
			BindDriver(driverNames[idx], bindType)
		}
	}
}
// BindType returns the bindvar type registered for driverName, or UNKNOWN when
// no registration exists (see BindDriver).
func BindType(driverName string) int {
	if registered, found := binds.Load(driverName); found {
		return registered.(int)
	}
	return UNKNOWN
}
// BindDriver records bindType as the bindvar type of driverName, replacing any
// earlier registration for that name. BindType reads the value back.
func BindDriver(driverName string, bindType int) {
	binds.Store(driverName, bindType)
}
// Rebind a query from the default bindtype (QUESTION) to the target bindtype.
func Rebind(bindType int, query string) string {
switch bindType {
case QUESTION, UNKNOWN:
return query
}
config := rebindConfigs[bindType]
tokens := sqltoken.Tokenize(query, config)
rqb := make([]byte, 0, len(query)+10)
var j int
for _, token := range tokens {
if token.Type != sqltoken.QuestionMark {
rqb = append(rqb, ([]byte)(token.Text)...)
continue
}
switch bindType {
case DOLLAR:
rqb = append(rqb, '$')
case NAMED:
rqb = append(rqb, ':', 'a', 'r', 'g')
case AT:
rqb = append(rqb, '@', 'p')
}
j++
rqb = strconv.AppendInt(rqb, int64(j), 10)
}
return string(rqb)
}
// asSliceForIn reports whether i is a slice that In should expand into a list
// of bindvars. On success it returns the slice itself as a reflect.Value; when
// i is a pointer to a slice, the pointer is dereferenced so callers can use
// v.Len and v.Index directly. []byte is a driver.Value and is never expanded,
// and nil values or nil pointers report false.
func asSliceForIn(i interface{}) (v reflect.Value, ok bool) {
	if i == nil {
		return reflect.Value{}, false
	}
	v = reflect.ValueOf(i)
	t := reflectx.Deref(v.Type())
	// Only expand slices
	if t.Kind() != reflect.Slice {
		return reflect.Value{}, false
	}
	// []byte is a driver.Value type so it should not be expanded
	if t == reflect.TypeOf([]byte{}) {
		return reflect.Value{}, false
	}
	// Deref only unwrapped the type. Unwrap the value too: calling Len on a
	// pointer Value panics. A nil *[]T has nothing to expand.
	for v.Kind() == reflect.Ptr {
		if v.IsNil() {
			return reflect.Value{}, false
		}
		v = v.Elem()
	}
	return v, true
}
// In expands slice values in args, returning the modified query string
// and a new arg list that can be executed by a database. The `query` should
// use the `?` bindVar. The return value uses the `?` bindVar.
//
// Only a `?` inside a parenthesized list whose nearest preceding word is IN
// (case-insensitive, e.g. `WHERE id IN (?)`) is expanded. A slice argument
// bound there becomes `?, ?, ...` with one placeholder per element, and its
// elements are appended to the argument list. Every other argument is passed
// through unchanged with its single placeholder. That includes scalars inside
// an IN list and slices outside one (such as `= ANY(?)`).
//
// driver.Valuer arguments are resolved first, so a Valuer that produces a slice
// is treated like that slice. If no argument is an expandable slice, query and
// args are returned as-is without parsing. Otherwise an error is returned for
// an empty slice, or when the number of placeholders differs from len(args).
func In(query string, args ...interface{}) (string, []interface{}, error) {
	// argMeta records, per argument, the (Valuer-resolved) value itself and,
	// when it is an expandable slice, its reflect.Value and length.
	type argMeta struct {
		v      reflect.Value
		i      interface{}
		length int
	}
	var flatArgsCount int
	var anySlices bool
	// Small argument lists use stack storage to avoid an allocation.
	var stackMeta [32]argMeta
	var meta []argMeta
	if len(args) <= len(stackMeta) {
		meta = stackMeta[:len(args)]
	} else {
		meta = make([]argMeta, len(args))
	}
	for i, arg := range args {
		if a, ok := arg.(driver.Valuer); ok {
			var err error
			arg, err = callValuerValue(a)
			if err != nil {
				return "", nil, err
			}
		}
		// Record the value for every argument. A slice that turns out not to
		// sit inside an IN list is then passed through as-is instead of nil.
		meta[i].i = arg
		if v, ok := asSliceForIn(arg); ok {
			meta[i].length = v.Len()
			meta[i].v = v
			anySlices = true
			flatArgsCount += meta[i].length
			if meta[i].length == 0 {
				return "", nil, errors.New("empty slice passed to 'in' query")
			}
		} else {
			flatArgsCount++
		}
	}
	// don't do any parsing if there aren't any slices; note that this means
	// some errors that we might have caught below will not be returned.
	if !anySlices {
		return query, args, nil
	}
	newArgs := make([]interface{}, 0, flatArgsCount)
	var buf strings.Builder
	buf.Grow(len(query) + len(", ?")*flatArgsCount)
	var arg int
	// The specific config doesn't matter; the tokenizer only needs to report
	// `?` placeholders as QuestionMark tokens.
	config := rebindConfigs[DOLLAR]
	tokens := sqltoken.Tokenize(query, config)
	inIn := false // inside the parenthesized list of an `IN (`
	for pos, token := range tokens {
		if !inIn && token.Type == sqltoken.Punctuation && token.Text == "(" {
			// the list belongs to IN when the nearest preceding word is "in"
			for i := pos - 1; i >= 0; i-- {
				if tokens[i].Type == sqltoken.Word {
					inIn = strings.EqualFold(tokens[i].Text, "in")
					break
				}
			}
		}
		switch {
		case token.Type == sqltoken.QuestionMark:
			if arg >= len(meta) {
				// if an argument wasn't passed, lets return an error; this is
				// not actually how database/sql Exec/Query works, but since we are
				// creating an argument list programmatically, we want to be able
				// to catch these programmer errors earlier.
				return "", nil, errors.New("number of bindVars exceeds arguments")
			}
			am := meta[arg]
			arg++
			if inIn && am.v.IsValid() {
				// expand the slice into one placeholder per element
				buf.WriteString("?")
				for si := 1; si < am.length; si++ {
					buf.WriteString(", ?")
				}
				newArgs = appendReflectSlice(newArgs, am.v, am.length)
			} else {
				// Scalars, and slices outside an IN list, keep their single
				// placeholder. A scalar inside IN (e.g. `IN (?, ?)`) has no
				// reflect.Value to expand and must not reach appendReflectSlice.
				newArgs = append(newArgs, am.i)
				buf.WriteString(token.Text)
			}
		case inIn && token.Type == sqltoken.Punctuation && token.Text == ")":
			inIn = false
			buf.WriteString(token.Text)
		default:
			buf.WriteString(token.Text)
		}
	}
	if arg < len(meta) {
		return "", nil, errors.New("number of bindVars less than number arguments")
	}
	return buf.String(), newArgs, nil
}
// appendReflectSlice appends the elements of the slice held in v to args, one
// interface{} per element, and returns the extended slice.
//
// []interface{}, []int and []string take a type-switch fast path that appends
// every element without per-element reflection; vlen is not consulted there.
// Any other slice type appends the first vlen elements via reflection.
func appendReflectSlice(args []interface{}, v reflect.Value, vlen int) []interface{} {
	switch typed := v.Interface().(type) {
	case []interface{}:
		return append(args, typed...)
	case []int:
		for _, n := range typed {
			args = append(args, n)
		}
	case []string:
		for _, s := range typed {
			args = append(args, s)
		}
	default:
		for idx := 0; idx < vlen; idx++ {
			elem := v.Index(idx)
			args = append(args, elem.Interface())
		}
	}
	return args
}
// callValuerValue calls vr.Value(), except when vr is a nil pointer whose
// pointed-to type itself implements driver.Valuer. There, Value is the
// compiler-generated pointer wrapper around a value-receiver method, and
// calling it would panic in the panicwrap method. That case is reported as a
// nil value with no error instead (Go issue 8415).
//
// This lets people implement driver.Valuer on value types and still use nil
// pointers to those types to mean nil/NULL, just like string/*string.
//
// This function is copied from database/sql/driver package
// and mirrored in the database/sql package.
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
	rv := reflect.ValueOf(vr)
	valuerType := reflect.TypeOf((*driver.Valuer)(nil)).Elem()
	isNilPointer := rv.Kind() == reflect.Pointer && rv.IsNil()
	if isNilPointer && rv.Type().Elem().Implements(valuerType) {
		return nil, nil
	}
	return vr.Value()
}
package sqlx
// Named Query Support
//
// * BindMap - bind query bindvars to map/struct args
// * NamedExec, NamedQuery - named query w/ struct or map
// * NamedStmt - a pre-compiled named query which is a prepared statement
//
// Internal Interfaces:
//
// * compileNamedQuery - rebind a named query, returning a query and list of names
// * bindArgs, bindMapArgs, bindAnyArgs - given a list of names, return an arglist
//
import (
"bytes"
"database/sql"
"fmt"
"iter"
"reflect"
"strconv"
"strings"
"github.com/muir/sqltoken"
"github.com/vinovest/sqlx/reflectx"
)
// GenericNamedStmt is a prepared statement that executes named queries. Prepare
// it the way you would execute a NamedQuery, then pass a struct or map when
// executing it; each named parameter is bound from that argument. T is the row
// type used by the generic helpers such as One, List and All. NamedStmt is the
// non-generic form, kept so existing user code continues to compile.
type GenericNamedStmt[T any] struct {
// Params lists the named parameters in placeholder order (compiled.names in PrepareNamed).
Params []string
// QueryString is the prepared query, with named parameters replaced by the driver's bindvars.
QueryString string
// Stmt is the underlying prepared statement.
Stmt *GenericStmt[T]
}
// NamedStmt is a prepared statement that executes named queries. Prepare it
// how you would execute a NamedQuery, but pass in a struct or map when executing.
// It is an alias for GenericNamedStmt[any].
type NamedStmt = GenericNamedStmt[any]
// Close closes the underlying prepared statement and returns its error.
func (n *GenericNamedStmt[T]) Close() error {
	stmt := n.Stmt
	return stmt.Close()
}
// Exec executes the named statement. Each named placeholder parameter is bound
// from the matching field or key of arg (a struct or map). A binding failure is
// returned without executing the statement.
func (n *GenericNamedStmt[T]) Exec(arg interface{}) (sql.Result, error) {
	bound, bindErr := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
	if bindErr != nil {
		return nil, bindErr
	}
	return n.Stmt.Exec(bound...)
}
// Query runs the named statement and returns the resulting *sql.Rows. Each named
// placeholder parameter is bound from the matching field or key of arg.
func (n *GenericNamedStmt[T]) Query(arg interface{}) (*sql.Rows, error) {
	bound, bindErr := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
	if bindErr != nil {
		return nil, bindErr
	}
	return n.Stmt.Query(bound...)
}
// QueryRow executes the named statement and returns at most one row. A
// *sql.Row cannot carry a pre-set error, so sqlx returns a *sqlx.Row; binding
// errors are stored in it and surface when the row is scanned.
// Any named placeholder parameters are replaced with fields from arg.
func (n *GenericNamedStmt[T]) QueryRow(arg interface{}) *Row {
	bound, bindErr := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
	if bindErr == nil {
		return n.Stmt.QueryRowx(bound...)
	}
	return &Row{err: bindErr}
}
// MustExec executes the named statement like Exec and panics if it fails.
// Any named placeholder parameters are replaced with fields from arg.
func (n *GenericNamedStmt[T]) MustExec(arg interface{}) sql.Result {
	result, err := n.Exec(arg)
	if err == nil {
		return result
	}
	panic(err)
}
// Queryx runs the named statement like Query, then wraps the result in *sqlx.Rows.
// The wrapper carries the statement's mapper and options.
// Any named placeholder parameters are replaced with fields from arg.
func (n *GenericNamedStmt[T]) Queryx(arg interface{}) (*Rows, error) {
	sqlRows, err := n.Query(arg)
	if err != nil {
		return nil, err
	}
	wrapped := &Rows{Rows: sqlRows, Mapper: n.Stmt.Mapper, options: n.Stmt.options}
	return wrapped, nil
}
// QueryRowx is an alias for QueryRow: because of limitations with QueryRow,
// both already return a *sqlx.Row.
// Any named placeholder parameters are replaced with fields from arg.
func (n *GenericNamedStmt[T]) QueryRowx(arg interface{}) *Row {
	row := n.QueryRow(arg)
	return row
}
// Select runs the named statement and scans every resulting row into dest,
// which must be a pointer to a slice. The rows are always closed.
// Any named placeholder parameters are replaced with fields from arg.
func (n *GenericNamedStmt[T]) Select(dest interface{}, arg interface{}) error {
	rs, err := n.Queryx(arg)
	if err != nil {
		return err
	}
	// close the rows even when scanning fails part-way through
	defer rs.Close()
	return scanAll(rs, dest, false)
}
// List runs the statement with arg and returns all rows as a slice of T. On
// error it returns whatever Select had collected together with the error.
func (n *GenericNamedStmt[T]) List(arg interface{}) ([]T, error) {
	var out []T
	if err := n.Select(&out, arg); err != nil {
		return out, err
	}
	return out, nil
}
// Get runs the named statement and scans the single resulting row into dest.
// Any named placeholder parameters are replaced with fields from arg.
func (n *GenericNamedStmt[T]) Get(dest interface{}, arg interface{}) error {
	return n.QueryRowx(arg).scanAny(dest, false)
}
// One runs the named statement and scans the single resulting row into a new T.
// Any named placeholder parameters are replaced with fields from arg.
func (n *GenericNamedStmt[T]) One(arg interface{}) (T, error) {
	row := n.QueryRowx(arg)
	var result T
	scanErr := row.scanAny(&result, false)
	return result, scanErr
}
// All runs the statement with arg and returns an iterator over the resulting
// rows, each scanned into a T with StructScan, for use with range.
//
// Errors are delivered through the iterator rather than by panicking:
//   - a binding or query error is yielded once, paired with a zero T;
//   - a scan error is yielded together with the row it occurred on;
//   - an error reported by rows.Err after the last row is yielded as a final
//     pair with a zero T.
//
// The query runs when All is called, and the rows are closed when iteration
// ends or the loop body breaks. Always range over the returned iterator so
// the rows get closed.
func (n *GenericNamedStmt[T]) All(arg interface{}) iter.Seq2[T, error] {
	rows, err := n.Queryx(arg)
	if err != nil {
		return func(yield func(T, error) bool) {
			var zero T
			yield(zero, err)
		}
	}
	return func(yield func(T, error) bool) {
		defer func() {
			_ = rows.Close() // best-effort cleanup; nothing to report it to
		}()
		for rows.Next() {
			var dest T
			scanErr := rows.StructScan(&dest)
			if !yield(dest, scanErr) {
				return
			}
		}
		// Next returns false both at the end and on failure; distinguish them.
		if iterErr := rows.Err(); iterErr != nil {
			var zero T
			yield(zero, iterErr)
		}
	}
}
// Prepare returns a transaction-specific copy of the statement when ndb is a
// *Tx. For any other Queryable, the receiver is returned unchanged.
//
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back (you do not need to close it).
func (n *GenericNamedStmt[T]) Prepare(ndb Queryable) *GenericNamedStmt[T] {
	tx, isTx := ndb.(*Tx)
	if !isTx {
		// outside a transaction the existing statement is reused as-is
		return n
	}
	txStmt := &GenericStmt[T]{
		Stmt:    tx.Stmt(n.Stmt.Stmt),
		options: n.Stmt.options,
		Mapper:  n.Stmt.Mapper,
	}
	return &GenericNamedStmt[T]{Params: n.Params, QueryString: n.QueryString, Stmt: txStmt}
}
// Unsafe returns a new GenericNamedStmt that shares Params and QueryString with
// the receiver. Its Stmt is the one returned by n.Stmt.Unsafe(). The receiver
// itself is not modified.
func (n *GenericNamedStmt[T]) Unsafe() *GenericNamedStmt[T] {
	unsafeStmt := n.Stmt.Unsafe()
	return &GenericNamedStmt[T]{
		Params:      n.Params,
		QueryString: n.QueryString,
		Stmt:        unsafeStmt,
	}
}
// getOptions returns the dbOptions of the underlying statement. It lets code
// reach the options without type-asserting the generic GenericNamedStmt[T].
func (n *GenericNamedStmt[T]) getOptions() *dbOptions {
	stmt := n.Stmt
	return stmt.options
}
// A union interface of preparer and binder, required to be able to prepare
// named statements (as the bindtype must be determined).
type namedPreparer interface {
Preparer
binder
}
// PrepareNamed compiles a named query for p's driver and prepares it. The
// result is a GenericNamedStmt whose generic helpers produce values of type T.
// Named parameters are rewritten to the bindvar syntax reported by
// BindType(p.DriverName()). Their names are kept in Params and bound when the
// statement is executed.
func PrepareNamed[T any](p namedPreparer, query string) (*GenericNamedStmt[T], error) {
	compiled, err := compileNamedQuery([]byte(query), BindType(p.DriverName()))
	if err != nil {
		return nil, err
	}
	prepared, err := Preparex[T](p, compiled.query)
	if err != nil {
		return nil, err
	}
	named := &GenericNamedStmt[T]{
		Params:      compiled.names,
		QueryString: compiled.query,
		Stmt:        prepared,
	}
	return named, nil
}
// convertMapStringInterface attempts to convert v to map[string]interface{}.
// Unlike v.(map[string]interface{}), this function works on named types that
// are convertible to map[string]interface{} as well. It returns false for a
// nil v and for any type that is not convertible (e.g. map[string]string).
func convertMapStringInterface(v interface{}) (map[string]interface{}, bool) {
	// Fast path: already the exact type, no reflection needed.
	if m, ok := v.(map[string]interface{}); ok {
		return m, true
	}
	if v == nil {
		// reflect.TypeOf(nil) is a nil Type; calling ConvertibleTo on it panics.
		return nil, false
	}
	mtype := reflect.TypeOf(map[string]interface{}(nil))
	t := reflect.TypeOf(v)
	if !t.ConvertibleTo(mtype) {
		return nil, false
	}
	return reflect.ValueOf(v).Convert(mtype).Interface().(map[string]interface{}), true
}
// bindAnyArgs builds the positional argument list for names from arg. If arg
// converts to map[string]interface{}, values are looked up by key
// (bindMapArgs). Otherwise they are read from struct fields through m
// (bindArgs).
func bindAnyArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) {
	maparg, isMap := convertMapStringInterface(arg)
	if !isMap {
		return bindArgs(names, arg, m)
	}
	return bindMapArgs(names, maparg)
}
// bindArgs returns the values of the struct fields named by names, in order,
// for use as positional query arguments. arg may be a struct or a pointer
// (possibly to another pointer) to one. Field names are resolved through m,
// so the mapper's tag and name rules apply. Used by the public BindStruct
// interface.
//
// An error is returned when a name has no matching field, or when arg is nil
// or a nil pointer while names is non-empty. With no names to bind, a nil arg
// yields an empty list.
func bindArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) {
	arglist := make([]interface{}, 0, len(names))
	// grab the indirected value of arg
	v := reflect.ValueOf(arg)
	for v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if !v.IsValid() {
		// arg was nil or a nil pointer: there is no struct to read from, and
		// v.Type() below would panic.
		if len(names) == 0 {
			return arglist, nil
		}
		return nil, fmt.Errorf("could not bind names %v from nil argument %#v", names, arg)
	}
	err := m.TraversalsByNameFunc(v.Type(), names, func(i int, t []int) error {
		if len(t) == 0 {
			return fmt.Errorf("could not find name %s in %#v", names[i], arg)
		}
		val := reflectx.FieldByIndexesReadOnly(v, t)
		arglist = append(arglist, val.Interface())
		return nil
	})
	return arglist, err
}
// bindMapArgs is the map counterpart of bindArgs: it returns arg[name] for
// each name, in order. On the first missing key it returns the values gathered
// so far together with an error.
func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, error) {
	out := make([]interface{}, len(names))
	for idx, name := range names {
		val, found := arg[name]
		if !found {
			return out[:idx], fmt.Errorf("could not find name %s in %#v", name, arg)
		}
		out[idx] = val
	}
	return out, nil
}
// bindStruct compiles a named query for bindType and binds its parameters from
// arg, which may be a struct or a map. Field names follow the same conventions
// as StructScan, including obeying the `db` struct tags. On any error the
// query is empty and the argument list is a non-nil empty slice.
func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
	compiled, err := compileNamedQuery([]byte(query), bindType)
	if err == nil {
		var arglist []interface{}
		if arglist, err = bindAnyArgs(compiled.names, arg, m); err == nil {
			return compiled.query, arglist, nil
		}
	}
	return "", []interface{}{}, err
}
// fixBound rewrites cq.query so that its VALUES tuple appears loop times,
// separated by commas. For example "VALUES (?, ?)" with loop=3 becomes
// "VALUES (?, ?),(?, ?),(?, ?)". cq.valuesStart and cq.valuesEnd delimit the
// tuple including its parentheses (end exclusive). Nothing changes when no
// tuple was recorded or loop < 2.
func fixBound(cq *compiledQueryResult, loop int) {
	if cq.valuesStart == nil || cq.valuesEnd == nil || loop < 2 {
		return
	}
	start, end := int(*cq.valuesStart), int(*cq.valuesEnd)
	tuple := cq.query[start:end]
	// exact final size: the original query plus (loop-1) copies of ",<tuple>"
	buffer := bytes.NewBuffer(make([]byte, 0, len(cq.query)+(loop-1)*(len(tuple)+1)))
	buffer.WriteString(cq.query[:end])
	for i := 1; i < loop; i++ {
		buffer.WriteByte(',')
		buffer.WriteString(tuple)
	}
	buffer.WriteString(cq.query[end:])
	cq.query = buffer.String()
}
// bindArray binds a named parameter query from an array or slice of structs
// or maps. The first VALUES tuple is repeated once per element, and the
// arguments of all elements are concatenated in order. An empty array or slice
// is an error.
func bindArray(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
	// Compile with QUESTION so the VALUES tuple can be copied verbatim; the
	// finished query is rebound to bindType at the end when needed.
	compiled, err := compileNamedQuery([]byte(query), QUESTION)
	if err != nil {
		return "", []interface{}{}, err
	}
	elems := reflect.ValueOf(arg)
	count := elems.Len()
	if count == 0 {
		return "", []interface{}{}, fmt.Errorf("length of array is 0: %#v", arg)
	}
	arglist := make([]interface{}, 0, len(compiled.names)*count)
	for idx := 0; idx < count; idx++ {
		elemArgs, bindErr := bindAnyArgs(compiled.names, elems.Index(idx).Interface(), m)
		if bindErr != nil {
			return "", []interface{}{}, bindErr
		}
		arglist = append(arglist, elemArgs...)
	}
	if count > 1 {
		fixBound(&compiled, count)
	}
	if bindType == QUESTION {
		return compiled.query, arglist, nil
	}
	return Rebind(bindType, compiled.query), arglist, nil
}
// bindMap compiles a named query for bindType and looks up each parameter in
// args. On a missing key it still returns the compiled query and the values
// found so far, along with the error.
func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) {
	compiled, err := compileNamedQuery([]byte(query), bindType)
	if err != nil {
		return "", []interface{}{}, err
	}
	values, bindErr := bindMapArgs(compiled.names, args)
	return compiled.query, values, bindErr
}
// namedParseConfigs holds, indexed by bindvar type, the sqltoken configuration
// compileNamedQuery uses to find `:name` parameters. Every entry accepts
// Unicode in colon-word names and tokenizes punctuation separately, so the
// parentheses of a VALUES tuple can be tracked. QUESTION and UNKNOWN share the
// MySQL configuration.
var namedParseConfigs = func() []sqltoken.Config {
configs := make([]sqltoken.Config, AT+1)
// PostgreSQL: recognize :name, and leave `$N` unrecognized.
pg := sqltoken.PostgreSQLConfig()
pg.NoticeColonWord = true
pg.ColonWordIncludesUnicode = true
pg.NoticeDollarNumber = false
pg.NoticeQuestionMark = true
pg.SeparatePunctuation = true
configs[DOLLAR] = pg
// Oracle: relies on OracleConfig's default colon-word handling (rebindConfigs
// turns it off explicitly, which suggests it defaults to on — confirm).
ora := sqltoken.OracleConfig()
ora.ColonWordIncludesUnicode = true
ora.NoticeQuestionMark = true
ora.SeparatePunctuation = true
configs[NAMED] = ora
// SQL Server: recognize :name, and leave `@word` unrecognized. Unlike the
// other entries, NoticeQuestionMark is not set here.
ssvr := sqltoken.SQLServerConfig()
ssvr.NoticeColonWord = true
ssvr.ColonWordIncludesUnicode = true
ssvr.NoticeAtWord = false
ssvr.SeparatePunctuation = true
configs[AT] = ssvr
mysql := sqltoken.MySQLConfig()
mysql.NoticeColonWord = true
mysql.ColonWordIncludesUnicode = true
mysql.NoticeQuestionMark = true
mysql.SeparatePunctuation = true
configs[QUESTION] = mysql
configs[UNKNOWN] = mysql
return configs
}()
// -- Compilation of Named Queries
// Historical FIXME, inherited from the original byte-oriented parser:
// compileNamedQuery was not safe for unicode named params, as a failing test
// could testify, and was meant to range over runes rather than bytes. Named
// parameters are now found by sqltoken with ColonWordIncludesUnicode enabled in
// every namedParseConfigs entry, which is intended to resolve this. Keep this
// note until a unicode named-param test confirms it. Preparing a NamedStmt
// compiles the query only once, which offsets the parsing cost paid by ad-hoc
// NamedExec/NamedQuery calls.
// compiledQueryResult is the output of compileNamedQuery.
type compiledQueryResult struct {
// the query string with the named parameters swapped out with bindvars
query string
// the parameter names (without the leading ':'), in the order their
// bindvars appear in query; a name used twice appears twice
names []string
// Byte offsets into query that delimit the first VALUES tuple; both are nil
// when no tuple was found. valuesStart is the offset of its "(" and
// valuesEnd is the offset just past the matching ")". So
// query[*valuesStart:*valuesEnd] is the whole tuple, parentheses included.
// fixBound uses them to repeat the tuple.
valuesStart *uint32
valuesEnd *uint32
}
// compileNamedQuery rewrites a named query into the bindvar syntax of bindType.
// It also returns the parameter names in placeholder order.
//
// Named parameters are the ColonWord tokens reported by the namedParseConfigs
// tokenizer for bindType. Each is replaced as follows:
//   - QUESTION and UNKNOWN: `?`
//   - DOLLAR: `$N`, with N counting from 1
//   - AT: `@pN`, with N counting from 1
//   - NAMED: kept as-is, because Oracle only supports named bind vars
//
// The byte span of the first VALUES tuple in the rewritten query is recorded in
// valuesStart/valuesEnd for fixBound.
//
// An error is returned for a bindType without a parse configuration, or when
// a VALUES tuple is left open.
func compileNamedQuery(qs []byte, bindType int) (compiledQueryResult, error) {
	r := compiledQueryResult{
		names: make([]string, 0, 10),
	}
	// namedParseConfigs only has entries for the known bind types; any other
	// value would index it out of range below.
	if bindType < 0 || bindType >= len(namedParseConfigs) {
		return r, fmt.Errorf("unsupported bind type %d", bindType)
	}
	curpos := uint32(0) // length of rebound so far = offset of the next output token
	rebound := make([]byte, 0, len(qs))
	inValues := false
	inValuesOpenCount := 0 // parenthesis depth inside the VALUES tuple
	currentVar := 1        // next number for $N / @pN bindvars
	tokens := sqltoken.Tokenize(string(qs), namedParseConfigs[bindType])
	for _, token := range tokens {
		if token.Type == sqltoken.Word && strings.EqualFold("values", token.Text) && !inValues {
			// current behavior: expand the first values and ignore the rest
			if r.valuesStart == nil { // did we already parse a values statement?
				inValues = true
			}
		}
		if inValues && token.Type == sqltoken.Punctuation {
			if token.Text == "(" {
				start := curpos
				inValuesOpenCount += 1
				if r.valuesStart == nil {
					r.valuesStart = &start
				}
			}
			if token.Text == ")" {
				inValuesOpenCount -= 1
				if inValuesOpenCount == 0 {
					end := curpos + 1 // just past this ")"
					r.valuesEnd = &end
					inValues = false
				}
			}
		}
		if token.Type != sqltoken.ColonWord {
			rebound = append(rebound, token.Text...)
			curpos += uint32(len(token.Text))
			continue
		}
		r.names = append(r.names, token.Text[1:])
		newBound := ""
		switch bindType {
		// oracle only supports named type bind vars even for positional
		case NAMED:
			newBound = token.Text
		case QUESTION, UNKNOWN:
			newBound = "?"
		case DOLLAR:
			newBound = "$" + strconv.Itoa(currentVar)
			currentVar++
		case AT:
			newBound = "@p" + strconv.Itoa(currentVar)
			currentVar++
		}
		rebound = append(rebound, newBound...)
		curpos += uint32(len(newBound))
	}
	// NOTE(review): inValues is only cleared by a closing ")", so this also
	// fires when the word VALUES is never followed by "(" at all.
	if inValues {
		return r, fmt.Errorf("missing closing bracket in VALUES")
	}
	r.query = string(rebound)
	return r, nil
}
// BindNamed binds a struct or a map to a query with named parameters.
// DEPRECATED: use sqlx.Named` instead of this, it may be removed in future.
func BindNamed(bindType int, query string, arg interface{}) (string, []interface{}, error) {
return bindNamedMapper(bindType, query, arg, mapper())
}
// Named takes a query using named parameters and an argument and
// returns a new query with a list of args that can be executed by
// a database. The return value uses the `?` bindvar.
func Named(query string, arg interface{}) (string, []interface{}, error) {
return bindNamedMapper(QUESTION, query, arg, mapper())
}
// bindNamedMapper picks a binder based on the kind of arg:
//   - a map with string keys goes to bindMap (after converting it to
//     map[string]interface{});
//   - an array or slice goes to bindArray (batch VALUES insert);
//   - anything else goes to bindStruct, which resolves fields through m.
//
// A nil arg is rejected with an error instead of panicking.
func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
	t := reflect.TypeOf(arg)
	if t == nil {
		// reflect.TypeOf(nil) is nil, and t.Kind() would panic.
		return "", nil, fmt.Errorf("sqlx.bindNamedMapper: nil argument")
	}
	switch k := t.Kind(); {
	case k == reflect.Map && t.Key().Kind() == reflect.String:
		mapArg, ok := convertMapStringInterface(arg)
		if !ok {
			return "", nil, fmt.Errorf("sqlx.bindNamedMapper: unsupported map type: %T", arg)
		}
		return bindMap(bindType, query, mapArg)
	case k == reflect.Array || k == reflect.Slice:
		return bindArray(bindType, query, arg, m)
	default:
		return bindStruct(bindType, query, arg, m)
	}
}
// NamedQuery binds a named query for e's driver and then runs Queryx on the
// result using the provided Ext (sqlx.Tx, sqlx.DB). arg may be a struct, a
// map[string]interface{}, or an array/slice of them.
func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) {
	bindType := BindType(e.DriverName())
	q, args, err := bindNamedMapper(bindType, query, arg, mapperFor(e))
	if err == nil {
		return e.Queryx(q, args...)
	}
	return nil, err
}
// NamedExec binds a named query for e's driver and then runs Exec on the
// result. It returns an error from either the binding or the execution.
func NamedExec(e Ext, query string, arg interface{}) (sql.Result, error) {
	bindType := BindType(e.DriverName())
	q, args, err := bindNamedMapper(bindType, query, arg, mapperFor(e))
	if err == nil {
		return e.Exec(q, args...)
	}
	return nil, err
}
package sqlx
import (
"context"
"database/sql"
"iter"
)
// A union interface of contextPreparer and binder, required to be able to
// prepare named statements with context (as the bindtype must be determined).
type namedPreparerContext interface {
PreparerContext
binder
}
// PrepareNamedContext compiles a named query for p's driver and prepares it
// using ctx. It returns a GenericNamedStmt whose generic helpers produce values
// of type T. To use the statement inside a transaction, call PrepareContext on
// the result.
func PrepareNamedContext[T any](ctx context.Context, p namedPreparerContext, query string) (*GenericNamedStmt[T], error) {
	compiled, err := compileNamedQuery([]byte(query), BindType(p.DriverName()))
	if err != nil {
		return nil, err
	}
	prepared, err := PreparexContext[T](ctx, p, compiled.query)
	if err != nil {
		return nil, err
	}
	named := &GenericNamedStmt[T]{
		Params:      compiled.names,
		QueryString: compiled.query,
		Stmt:        prepared,
	}
	return named, nil
}
// PrepareContext returns a transaction-specific copy of the statement when ndb
// is a *Tx. For any other Queryable, the receiver is returned unchanged.
//
// Prefer this over Prepare (without context): it passes ctx to
// tx.StmtContext, which uses the connection found in context.
//
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back (you do not need to close it).
func (n *GenericNamedStmt[T]) PrepareContext(ctx context.Context, ndb Queryable) *GenericNamedStmt[T] {
	tx, isTx := ndb.(*Tx)
	if !isTx {
		// outside a transaction the existing statement is reused as-is
		return n
	}
	txStmt := &GenericStmt[T]{
		Stmt:    tx.StmtContext(ctx, n.Stmt.Stmt),
		options: n.Stmt.options,
		Mapper:  n.Stmt.Mapper,
	}
	return &GenericNamedStmt[T]{Params: n.Params, QueryString: n.QueryString, Stmt: txStmt}
}
// ExecContext executes the named statement with ctx. Each named placeholder
// parameter is bound from the matching field or key of arg. A binding failure
// is returned without executing the statement.
func (n *GenericNamedStmt[T]) ExecContext(ctx context.Context, arg interface{}) (sql.Result, error) {
	bound, bindErr := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
	if bindErr != nil {
		return nil, bindErr
	}
	return n.Stmt.ExecContext(ctx, bound...)
}
// QueryContext runs the named statement with ctx and returns the resulting
// *sql.Rows. Each named placeholder parameter is bound from the matching field
// or key of arg.
func (n *GenericNamedStmt[T]) QueryContext(ctx context.Context, arg interface{}) (*sql.Rows, error) {
	bound, bindErr := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
	if bindErr != nil {
		return nil, bindErr
	}
	return n.Stmt.QueryContext(ctx, bound...)
}
// QueryRowContext executes the named statement with ctx and returns at most
// one row. A *sql.Row cannot carry a pre-set error, so sqlx returns a
// *sqlx.Row; binding errors are stored in it and surface when it is scanned.
// Any named placeholder parameters are replaced with fields from arg.
func (n *GenericNamedStmt[T]) QueryRowContext(ctx context.Context, arg interface{}) *Row {
	bound, bindErr := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
	if bindErr == nil {
		return n.Stmt.QueryRowxContext(ctx, bound...)
	}
	return &Row{err: bindErr}
}
// MustExecContext executes the named statement like ExecContext and panics if
// it fails.
// Any named placeholder parameters are replaced with fields from arg.
func (n *GenericNamedStmt[T]) MustExecContext(ctx context.Context, arg interface{}) sql.Result {
	result, err := n.ExecContext(ctx, arg)
	if err == nil {
		return result
	}
	panic(err)
}
// QueryxContext runs the named statement like QueryContext, then wraps the
// result in *sqlx.Rows carrying the statement's mapper and options.
// Any named placeholder parameters are replaced with fields from arg.
func (n *GenericNamedStmt[T]) QueryxContext(ctx context.Context, arg interface{}) (*Rows, error) {
	sqlRows, err := n.QueryContext(ctx, arg)
	if err != nil {
		return nil, err
	}
	wrapped := &Rows{Rows: sqlRows, Mapper: n.Stmt.Mapper, options: n.Stmt.options}
	return wrapped, nil
}
// QueryRowxContext is an alias for QueryRowContext: because of limitations
// with QueryRow, both already return a *sqlx.Row.
// Any named placeholder parameters are replaced with fields from arg.
func (n *GenericNamedStmt[T]) QueryRowxContext(ctx context.Context, arg interface{}) *Row {
	row := n.QueryRowContext(ctx, arg)
	return row
}
// SelectContext runs the named statement with ctx and scans every resulting
// row into dest, which must be a pointer to a slice. The rows are always
// closed.
// Any named placeholder parameters are replaced with fields from arg.
func (n *GenericNamedStmt[T]) SelectContext(ctx context.Context, dest interface{}, arg interface{}) error {
	rs, err := n.QueryxContext(ctx, arg)
	if err != nil {
		return err
	}
	// close the rows even when scanning fails part-way through
	defer rs.Close()
	return scanAll(rs, dest, false)
}
// GetContext runs the named statement with ctx and scans the single resulting
// row into dest.
// Any named placeholder parameters are replaced with fields from arg.
func (n *GenericNamedStmt[T]) GetContext(ctx context.Context, dest interface{}, arg interface{}) error {
	return n.QueryRowxContext(ctx, arg).scanAny(dest, false)
}
// OneContext runs the named statement with ctx and scans the single resulting
// row into a new T.
// Any named placeholder parameters are replaced with fields from arg.
func (n *GenericNamedStmt[T]) OneContext(ctx context.Context, arg interface{}) (T, error) {
	row := n.QueryRowxContext(ctx, arg)
	var result T
	scanErr := row.scanAny(&result, false)
	return result, scanErr
}
// AllContext runs the statement with ctx and arg and returns an iterator over
// the resulting rows, each scanned into a T with StructScan, for use with range.
//
// Errors are delivered through the iterator rather than by panicking:
//   - a binding or query error is yielded once, paired with a zero T;
//   - a scan error is yielded together with the row it occurred on;
//   - if ctx is done before the next row is scanned, ctx.Err() is yielded and
//     iteration stops;
//   - an error reported by rows.Err after the last row is yielded as a final
//     pair with a zero T.
//
// The query runs when AllContext is called, and the rows are closed when
// iteration ends or the loop body breaks. Always range over the returned
// iterator so the rows get closed.
func (n *GenericNamedStmt[T]) AllContext(ctx context.Context, arg interface{}) iter.Seq2[T, error] {
	rows, err := n.QueryxContext(ctx, arg)
	if err != nil {
		return func(yield func(T, error) bool) {
			var zero T
			yield(zero, err)
		}
	}
	return func(yield func(T, error) bool) {
		defer func() {
			_ = rows.Close() // best-effort cleanup; nothing to report it to
		}()
		for rows.Next() {
			if ctxErr := ctx.Err(); ctxErr != nil {
				// report the cancellation instead of stopping silently
				var zero T
				yield(zero, ctxErr)
				return
			}
			var dest T
			scanErr := rows.StructScan(&dest)
			if !yield(dest, scanErr) {
				return
			}
		}
		// Next returns false both at the end and on failure; distinguish them.
		if iterErr := rows.Err(); iterErr != nil {
			var zero T
			yield(zero, iterErr)
		}
	}
}
// NamedQueryContext binds a named query for e's driver and then runs
// QueryxContext on the result using the provided ExtContext (sqlx.Tx,
// sqlx.DB). arg may be a struct, a map[string]interface{}, or an array/slice of
// them.
func NamedQueryContext(ctx context.Context, e ExtContext, query string, arg interface{}) (*Rows, error) {
	bindType := BindType(e.DriverName())
	q, args, err := bindNamedMapper(bindType, query, arg, mapperFor(e))
	if err == nil {
		return e.QueryxContext(ctx, q, args...)
	}
	return nil, err
}
// NamedExecContext binds a named query for e's driver and then runs
// ExecContext on the result. It returns an error from either the binding or
// the execution.
func NamedExecContext(ctx context.Context, e ExtContext, query string, arg interface{}) (sql.Result, error) {
	bindType := BindType(e.DriverName())
	q, args, err := bindNamedMapper(bindType, query, arg, mapperFor(e))
	if err == nil {
		return e.ExecContext(ctx, q, args...)
	}
	return nil, err
}
// Package reflectx implements extensions to the standard reflect lib suitable
// for implementing marshalling and unmarshalling packages. The main Mapper type
// allows for Go-compatible named attribute access, including accessing embedded
// struct attributes and the ability to use functions and struct tags to
// customize field names.
package reflectx
import (
"reflect"
"runtime"
"strings"
"sync"
)
// A FieldInfo is metadata for a struct field. FieldInfos form a tree rooted at
// StructMap.Tree.
type FieldInfo struct {
// Index is the traversal (a field index at each nesting level) from the root
// struct to this field; FieldMap and FieldByName pass it to FieldByIndexes.
Index []int
// Path is presumably the field's key in StructMap.Paths — confirm against getMapping.
Path string
// Field is the reflect metadata of the struct field itself.
Field reflect.StructField
// Zero presumably holds a zero value of the field's type — confirm against getMapping.
Zero reflect.Value
// Name is presumably the field's mapped name (tag or mapper result) — confirm against getMapping.
Name string
// Options presumably holds parsed tag options — confirm against getMapping.
Options map[string]string
// Embedded presumably marks an embedded (anonymous) field — confirm against getMapping.
Embedded bool
// Children are indexed by field position; GetByTraversal treats nil entries
// as absent.
Children []*FieldInfo
// Parent is presumably the enclosing node in the tree (nil at the root) — confirm.
Parent *FieldInfo
}
// A StructMap is an index of field metadata for a struct, as built and cached
// by Mapper.TypeMap.
type StructMap struct {
// Tree is the root node; GetByTraversal walks down its Children by field index.
Tree *FieldInfo
// Index presumably lists every FieldInfo in the tree — confirm against getMapping.
Index []*FieldInfo
// Paths maps a path string to its field; see GetByPath.
Paths map[string]*FieldInfo
// Names maps a mapped field name to its field; FieldMap, FieldByName and
// FieldsByName look names up here.
Names map[string]*FieldInfo
// Leafs presumably groups fields by leaf name, several fields per name being
// possible — confirm against getMapping.
Leafs map[string][]*FieldInfo
}
// GetByPath returns the *FieldInfo stored under path in f.Paths, or nil when
// there is no such path.
func (f StructMap) GetByPath(path string) *FieldInfo {
	if fi, found := f.Paths[path]; found {
		return fi
	}
	return nil
}
// GetByTraversal returns the *FieldInfo reached by following index from the
// root of f.Tree, one child position per level. It is analogous to
// reflect.FieldByIndex but uses the cached tree rather than re-running the
// reflect machinery. It returns nil for an empty index, or when a position is
// past the end of a node's children or lands on a nil child.
func (f StructMap) GetByTraversal(index []int) *FieldInfo {
	if len(index) == 0 {
		return nil
	}
	node := f.Tree
	for depth := 0; depth < len(index); depth++ {
		pos := index[depth]
		if pos >= len(node.Children) {
			return nil
		}
		next := node.Children[pos]
		if next == nil {
			return nil
		}
		node = next
	}
	return node
}
// Mapper is a general purpose mapper of names to struct fields. A Mapper
// behaves like most marshallers in the standard library, obeying a field tag
// for name mapping but also providing a basic transform function.
// A Mapper contains a sync.Mutex and must not be copied; use it via *Mapper.
type Mapper struct {
// cache memoizes the StructMap built for each type by TypeMap; guarded by mutex.
cache map[reflect.Type]*StructMap
// tagName is the struct tag consulted for field names; ignored when empty.
tagName string
// tagMapFunc, if set, transforms tag values (e.g. json's "name,omitempty").
tagMapFunc func(string) string
// mapFunc, if set, produces the mapped name for fields without a tag.
mapFunc func(string) string
// mutex guards cache.
mutex sync.Mutex
}
// NewMapper returns a Mapper that reads field names from the tagName struct
// tag, with an empty cache and no name-transform functions.
// If tagName is the empty string, it is ignored.
func NewMapper(tagName string) *Mapper {
	m := new(Mapper)
	m.tagName = tagName
	m.cache = make(map[reflect.Type]*StructMap)
	return m
}
// NewMapperTagFunc returns a Mapper with two transforms: mapFunc for Go field
// names and tagMapFunc for tag values. The tag form is useful for tags like
// json, whose values can look like "name,omitempty".
func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper {
	m := &Mapper{tagName: tagName, mapFunc: mapFunc, tagMapFunc: tagMapFunc}
	m.cache = make(map[reflect.Type]*StructMap)
	return m
}
// NewMapperFunc returns a Mapper that honors the tagName struct tag when
// present. For untagged fields the mapped name is f(field.Name); tags take
// precedence over f.
func NewMapperFunc(tagName string, f func(string) string) *Mapper {
	m := &Mapper{tagName: tagName, mapFunc: f}
	m.cache = make(map[reflect.Type]*StructMap)
	return m
}
// TypeMap returns the StructMap for t, a mapping of field names to the
// traversals that reach those fields. The map is built with getMapping on
// first use and cached for later calls. The cache is guarded by m.mutex, so
// TypeMap can be called concurrently on a shared Mapper.
func (m *Mapper) TypeMap(t reflect.Type) *StructMap {
	m.mutex.Lock()
	// Deferred so that a panic inside getMapping cannot leave the mutex
	// locked and deadlock every later caller.
	defer m.mutex.Unlock()
	if mapping, ok := m.cache[t]; ok {
		return mapping
	}
	mapping := getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc)
	m.cache[t] = mapping
	return mapping
}
// FieldMap returns every mapped field name of v together with that field's
// reflect.Value. Panics if v's Kind is not Struct, or v is not Indirectable
// to a struct kind.
func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value {
	v = reflect.Indirect(v)
	mustBe(v, reflect.Struct)
	byName := m.TypeMap(v.Type()).Names
	fields := make(map[string]reflect.Value, len(byName))
	for name, info := range byName {
		fields[name] = FieldByIndexes(v, info.Index)
	}
	return fields
}
// FieldByName returns the field of v whose mapped name is name. v may be a
// struct or a pointer to one; anything else panics. Nil pointers along the
// field's traversal are allocated, as in FieldByIndexes.
//
// If no field is mapped to name, the zero reflect.Value is returned. Check the
// result with IsValid. This matches FieldsByName.
func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value {
	v = reflect.Indirect(v)
	mustBe(v, reflect.Struct)

	fi, ok := m.TypeMap(v.Type()).Names[name]
	if !ok {
		// Return the zero Value, not v itself. Returning v made a miss
		// indistinguishable from a hit on a struct-typed field.
		return reflect.Value{}
	}
	return FieldByIndexes(v, fi.Index)
}
// FieldsByName resolves each entry of names to the corresponding field of v
// and returns the results in the same order. v may be a struct or a pointer to
// one; anything else panics. An unknown name yields the zero reflect.Value at
// its position.
func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value {
	v = reflect.Indirect(v)
	mustBe(v, reflect.Struct)

	byName := m.TypeMap(v.Type()).Names
	out := make([]reflect.Value, len(names))
	for pos, name := range names {
		if info, found := byName[name]; found {
			out[pos] = FieldByIndexes(v, info.Index)
		}
		// Unmatched names leave out[pos] as the zero reflect.Value.
	}
	return out
}
// TraversalsByName returns one struct traversal (a slice of field indexes) per
// entry in names, in order. Names that cannot be mapped produce an empty,
// non-nil slice. It panics if t is neither a struct nor a pointer to one.
func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int {
	traversals := make([][]int, 0, len(names))
	collect := func(_ int, index []int) error {
		if index == nil {
			index = []int{}
		}
		traversals = append(traversals, index)
		return nil
	}
	// collect never fails, so the returned error is always nil.
	_ = m.TraversalsByNameFunc(t, names, collect)
	return traversals
}
// TraversalsByNameFunc resolves each entry of names to a struct traversal (a
// slice of field indexes). For each entry it calls fn with the entry's
// position and that traversal, or with nil when the name cannot be mapped.
// It panics if t is neither a struct nor a pointer to one. It returns the
// first non-nil error from fn, or nil.
//
// Each name is resolved as follows:
//  1. An exact match on a field path (StructMap.Names) wins.
//  2. Otherwise the name is looked up as a leaf name (StructMap.Leafs). When
//     several fields share a leaf name (e.g. "id" in two nested structs),
//     repeated occurrences are assigned to successive fields. Once those are
//     used up, extra occurrences reuse the first field so existing queries
//     keep working.
//
// Usage counts are keyed by leaf key, the same key StructMap.Leafs uses (the
// field name passed through the Mapper's mapFunc). Lookups and increments
// therefore agree even when mapFunc changes a tag value, e.g. `db:"ID"` with
// strings.ToLower.
func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(int, []int) error) error {
	t = Deref(t)
	mustBe(t, reflect.Struct)
	tm := m.TypeMap(t)

	// leafKey mirrors how getMapping keys StructMap.Leafs.
	leafKey := func(fi *FieldInfo) string {
		if m.mapFunc != nil {
			return m.mapFunc(fi.Name)
		}
		return fi.Name
	}

	used := make(map[string]int, len(names))
	for i, name := range names {
		var index []int
		if fi, ok := tm.Names[name]; ok {
			// An exact path match consumes one use of its leaf name.
			used[leafKey(fi)]++
			index = fi.Index
		} else if leafs, ok := tm.Leafs[name]; ok {
			switch n := used[name]; {
			case n < len(leafs):
				index = leafs[n].Index
				used[name]++
			case len(leafs) > 0:
				// More occurrences than matching fields: map the extras onto
				// the first field, as earlier versions did.
				index = leafs[0].Index
			}
		} else {
			used[name]++
		}
		if err := fn(i, index); err != nil {
			return err
		}
	}
	return nil
}
// FieldByIndexes walks v along the struct traversal in indexes and returns the
// field it ends on. Each step first dereferences a pointer, if there is one.
// The result is meant to be written to, so any nil pointer or nil map met on
// the way is replaced by a newly allocated value. v must therefore be
// settable whenever such a nil is encountered.
func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
	cur := v
	for _, idx := range indexes {
		cur = reflect.Indirect(cur).Field(idx)
		switch {
		case cur.Kind() == reflect.Ptr && cur.IsNil():
			cur.Set(reflect.New(cur.Type().Elem()))
		case cur.Kind() == reflect.Map && cur.IsNil():
			cur.Set(reflect.MakeMap(cur.Type()))
		}
	}
	return cur
}
// FieldByIndexesReadOnly follows the struct traversal in indexes from v and
// returns the field reached. Unlike FieldByIndexes it never allocates. The
// value is only read, so nil pointers are left alone; dereferencing one along
// the path yields an invalid Value.
func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value {
	for n := 0; n < len(indexes); n++ {
		v = reflect.Indirect(v).Field(indexes[n])
	}
	return v
}
// Deref is the reflect.Type counterpart of reflect.Indirect. It strips exactly
// one level of pointer and returns non-pointer types unchanged.
func Deref(t reflect.Type) reflect.Type {
	if t.Kind() != reflect.Ptr {
		return t
	}
	return t.Elem()
}
// -- helpers & utilities --
// kinder is the one method shared by reflect.Value and reflect.Type. It lets
// mustBe check the Kind of either (FieldMap passes a Value,
// TraversalsByNameFunc passes a Type).
type kinder interface {
Kind() reflect.Kind
}
// mustBe panics with a *reflect.ValueError when v's Kind differs from
// expected. The error names the exported method that called mustBe.
func mustBe(v kinder, expected reflect.Kind) {
	k := v.Kind()
	if k == expected {
		return
	}
	panic(&reflect.ValueError{Method: methodName(), Kind: k})
}
// methodName reports the fully qualified name of the function two frames up
// the stack, i.e. the caller of whoever called methodName. It returns
// "unknown method" when that frame cannot be resolved.
func methodName() string {
	pc, _, _, _ := runtime.Caller(2)
	if fn := runtime.FuncForPC(pc); fn != nil {
		return fn.Name()
	}
	return "unknown method"
}
// typeQueue is one pending entry in getMapping's breadth-first walk over
// struct types.
type typeQueue struct {
// t is the already-dereferenced type whose fields are still to be visited.
t reflect.Type
// fi is the node that receives the visited fields as its Children.
fi *FieldInfo
// pp is the dotted path prefix given to the visited fields.
pp string // Parent path
}
// apnd returns a new slice holding the elements of is followed by i. Unlike
// the built-in append it never shares is's backing array. Sibling traversals
// built from the same parent index therefore cannot overwrite one another.
func apnd(is []int, i int) []int {
	out := make([]int, 0, len(is)+1)
	out = append(out, is...)
	return append(out, i)
}
// mapf is a string-to-string name transform. It is used both for Go field
// names (mapFunc) and for raw tag values (tagMapFunc).
type mapf func(string) string

// parseName determines the tag and the target name for field. tagName is the
// struct tag key to look for (e.g. "db" for `db:"foo"` tags). mapFunc maps the
// Go field name to a target name. tagMapFunc, when set, is applied to the
// whole tag value before the name is separated from its options.
//
// The returned tag is "" when tagName is empty or the field has no such tag
// key; fieldName is then the (mapped) Go field name. Otherwise fieldName is
// the part of the tag before the first comma.
func parseName(field reflect.StructField, tagName string, mapFunc, tagMapFunc mapf) (tag, fieldName string) {
	// Start from the Go field name, optionally transformed by mapFunc.
	fieldName = field.Name
	if mapFunc != nil {
		fieldName = mapFunc(fieldName)
	}
	if tagName == "" {
		return "", fieldName
	}

	// Lookup reports ok only when tagName itself is a key of the tag. A plain
	// substring test such as strings.Contains(tag, "db:") also matches longer
	// keys like `sqldb:"x"`. Get("db") then yielded "" and the field silently
	// lost its name.
	raw, ok := field.Tag.Lookup(tagName)
	if !ok {
		return "", fieldName
	}
	tag = raw

	// The mapper function runs on the whole tag, before the name is split
	// from its options.
	if tagMapFunc != nil {
		tag = tagMapFunc(tag)
	}

	// The name is everything before the first comma; the rest are options.
	fieldName, _, _ = strings.Cut(tag, ",")
	return tag, fieldName
}
// parseOptions returns the comma-separated options that follow the name in a
// tag value (the first comma-separated element is skipped). An option written
// as "key=value" is stored as key -> value. The split happens on the first
// "=" only, so values may themselves contain "=". A bare option maps to "".
// The result is never nil.
func parseOptions(tag string) map[string]string {
	parts := strings.Split(tag, ",")
	options := make(map[string]string, len(parts)-1)
	for _, opt := range parts[1:] {
		key, value, _ := strings.Cut(opt, "=")
		options[key] = value
	}
	return options
}
// getMapping returns a mapping for the t type, using the tagName, mapFunc and
// tagMapFunc to determine the canonical names of fields.
//
// The type is walked breadth-first starting from Deref(t). Every visited field
// is recorded as a FieldInfo in the resulting StructMap:
//   - Index lists all visited fields in BFS order.
//   - Tree is the root node; each node's Children are indexed by field position.
//   - Paths is keyed by the dotted field path (e.g. "inner.id").
//   - Names is keyed by path as well, but only for named, non-embedded fields.
//   - Leafs is keyed by mapFunc(fi.Name) and may hold several fields that
//     share a leaf name.
func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc mapf) *StructMap {
m := []*FieldInfo{}
root := &FieldInfo{}
queue := []typeQueue{}
queue = append(queue, typeQueue{Deref(t), root, ""})
QueueLoop:
for len(queue) != 0 {
// pop the first item off of the queue
tq := queue[0]
queue = queue[1:]
// ignore recursive field: skip a node whose field type matches one of its
// ancestors' field types, which would otherwise be expanded forever
for p := tq.fi.Parent; p != nil; p = p.Parent {
if tq.fi.Field.Type == p.Field.Type {
continue QueueLoop
}
}
nChildren := 0
if tq.t.Kind() == reflect.Struct {
nChildren = tq.t.NumField()
}
// NOTE(review): this replaces any Children slice allocated when the entry
// was queued below; the two sizes are computed the same way.
tq.fi.Children = make([]*FieldInfo, nChildren)
// iterate through all of its fields
for fieldPos := 0; fieldPos < nChildren; fieldPos++ {
f := tq.t.Field(fieldPos)
// parse the tag and the target name using the mapping options for this field
tag, name := parseName(f, tagName, mapFunc, tagMapFunc)
// if the name is "-", disabled via a tag, skip it
if name == "-" {
continue
}
fi := FieldInfo{
Field: f,
Name: name,
Zero: reflect.New(f.Type).Elem(),
Options: parseOptions(tag),
}
// if the path is empty this path is just the name
if tq.pp == "" {
fi.Path = fi.Name
} else {
fi.Path = tq.pp + "." + fi.Name
}
// skip unexported fields (unexported embedded fields are still walked so
// their exported fields can be promoted)
if len(f.PkgPath) != 0 && !f.Anonymous {
continue
}
// bfs search of anonymous embedded structs
if f.Anonymous {
// An untagged embedded struct promotes its fields into the parent's
// path; a tagged one nests them under its own path.
pp := tq.pp
if tag != "" {
pp = fi.Path
}
fi.Embedded = true
fi.Index = apnd(tq.fi.Index, fieldPos)
nChildren := 0
ft := Deref(f.Type)
if ft.Kind() == reflect.Struct {
nChildren = ft.NumField()
}
fi.Children = make([]*FieldInfo, nChildren)
queue = append(queue, typeQueue{Deref(f.Type), &fi, pp})
} else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) {
// Named struct (or pointer-to-struct) fields are also descended into,
// with their own path as prefix.
fi.Index = apnd(tq.fi.Index, fieldPos)
fi.Children = make([]*FieldInfo, Deref(f.Type).NumField())
queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path})
}
// NOTE(review): Index is (re)assigned here for every field, which makes the
// assignments in the branches above redundant.
fi.Index = apnd(tq.fi.Index, fieldPos)
// Parent is set after queueing; the queued entry holds &fi and so sees it
// when dequeued (the recursion guard above relies on it).
fi.Parent = tq.fi
tq.fi.Children[fieldPos] = &fi
m = append(m, &fi)
}
}
flds := &StructMap{
Index: m,
Tree: root,
Paths: map[string]*FieldInfo{},
Names: map[string]*FieldInfo{},
Leafs: map[string][]*FieldInfo{},
}
for _, fi := range flds.Index {
// check if nothing has already been pushed with the same path
// sometimes you can choose to override a type using embedded struct
fld, ok := flds.Paths[fi.Path]
if !ok || fld.Embedded {
flds.Paths[fi.Path] = fi
if fi.Name != "" && !fi.Embedded {
flds.Names[fi.Path] = fi
// NOTE(review): for untagged fields fi.Name already had mapFunc
// applied by parseName, so mapFunc runs twice here. That is harmless
// for idempotent functions such as strings.ToLower.
mappedName := fi.Name
if mapFunc != nil {
mappedName = mapFunc(mappedName)
}
if _, lok := flds.Leafs[mappedName]; !lok {
flds.Leafs[mappedName] = []*FieldInfo{}
}
flds.Leafs[mappedName] = append(flds.Leafs[mappedName], fi)
}
}
}
return flds
}
package sqlx
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"iter"
"os"
"path/filepath"
"reflect"
"strings"
"sync"
"github.com/vinovest/sqlx/reflectx"
)
// ErrMultiRows is returned by functions which are expected to work with result sets
// that only contain a single row but multiple rows were returned.
// This typically indicates an issue with the query such as a missing join criteria or
// limit condition or the use of Get(...) when Select(...) was intended.
//
// Row.Scan returns it unwrapped when a second row is present. By then the first
// row has already been scanned into the destination.
var ErrMultiRows = errors.New("sql: multiple rows returned")
// NameMapper is used to map column names to struct field names. By default,
// it uses strings.ToLower to lowercase struct field names. It can be set
// to whatever you want, but it is encouraged to be set before sqlx is used
// as name-to-field mappings are cached after first use on a type.
//
// Although the NameMapper is convenient, in practice it should not
// be relied on except for application code. If you are writing a library
// that uses sqlx, you should be aware that the name mappings you expect
// can be overridden by your user's application.
//
// mapper() reads NameMapper while holding mprMu, but assignments to NameMapper
// are not synchronized. Set it before sqlx is used concurrently.
var NameMapper = strings.ToLower
// origMapper records the NameMapper value from which the cached default mapper
// (mpr) was built. mapper() compares it with reflect.ValueOf(NameMapper) to
// notice reassignment.
var origMapper = reflect.ValueOf(NameMapper)
// mpr is the package-level default Mapper. Rather than creating on init, this
// is created when necessary so that importers have time to customize the
// NameMapper.
var mpr *reflectx.Mapper
// mprMu protects mpr and origMapper.
var mprMu sync.Mutex
// mapper returns the package-level default Mapper. It reads the "db" struct
// tag and maps untagged names through the current NameMapper. The Mapper, and
// with it its per-type cache, is created lazily and rebuilt whenever
// NameMapper has been reassigned since the last build.
func mapper() *reflectx.Mapper {
	mprMu.Lock()
	defer mprMu.Unlock()

	current := reflect.ValueOf(NameMapper)
	if mpr == nil || current != origMapper {
		mpr = reflectx.NewMapperFunc("db", NameMapper)
		// Record the source of the build in both cases. Without this on the
		// first build, a NameMapper customised before first use triggered a
		// second rebuild on the next call and threw away the fresh cache.
		origMapper = current
	}
	return mpr
}
// isScannable reports whether a destination of type t should be passed to
// Scan directly instead of being filled field by field. It returns true when:
//   - *t implements sql.Scanner,
//   - t is not a struct, or
//   - t is a struct with no mapped (exported, non-skipped) fields.
func isScannable(t reflect.Type) bool {
	switch {
	case reflect.PointerTo(t).Implements(_scannerInterface):
		return true
	case t.Kind() != reflect.Struct:
		return true
	}
	// Any mapper will do here: only the number of mapped fields matters, not
	// how they are named.
	return len(mapper().TypeMap(t).Index) == 0
}
// ColScanner is an interface used by MapScan and SliceScan.
// *Row implements it directly, and *Rows implements it through its embedded
// *sql.Rows.
type ColScanner interface {
Columns() ([]string, error)
Scan(dest ...interface{}) error
Err() error
}
// Queryer is an interface used by Get and Select.
// *DB, *Tx and the internal *qStmt wrapper implement it.
type Queryer interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
Queryx(query string, args ...interface{}) (*Rows, error)
QueryRowx(query string, args ...interface{}) *Row
}
// Execer is an interface used by MustExec and LoadFile.
// *DB and *Tx implement it through their embedded database/sql types, and the
// internal *qStmt wrapper implements it directly.
type Execer interface {
Exec(query string, args ...interface{}) (sql.Result, error)
}
// binder is an interface for something which can bind queries (Tx, DB).
// DriverName selects the bindvar style; Rebind and BindNamed rewrite queries
// into that style.
type binder interface {
DriverName() string
Rebind(string) string
BindNamed(string, interface{}) (string, []interface{}, error)
}
// Ext is a union interface which can bind, query, and exec, used by
// NamedQuery and NamedExec. *DB and *Tx both satisfy it.
type Ext interface {
binder
Queryer
Execer
}
// Preparer is an interface used by Preparex.
// *DB and *Tx satisfy it through their embedded *sql.DB and *sql.Tx.
type Preparer interface {
Prepare(query string) (*sql.Stmt, error)
}
// optionalContainer lets getOptions reach the options of generic types such as
// *GenericStmt[T] and *qStmt[T]. A type switch cannot name those without
// fixing T.
type optionalContainer interface {
getOptions() *dbOptions
}
// getOptions returns the dbOptions carried by i. DB, Tx, Conn, Row and Rows
// (as values or pointers) are recognised directly; generic wrappers are
// reached through optionalContainer.
//
// Any other type, or a wrapper whose options were never set (for example a
// struct literal built without NewDb/Open), yields a fresh zero dbOptions.
// Callers may therefore dereference the result unconditionally.
func getOptions(i interface{}) *dbOptions {
	var opts *dbOptions
	switch v := i.(type) {
	case DB:
		opts = v.options
	case *DB:
		opts = v.options
	case Tx:
		opts = v.options
	case *Tx:
		opts = v.options
	case Conn:
		opts = v.options
	case *Conn:
		opts = v.options
	case *Row:
		opts = v.options
	case Row:
		opts = v.options
	case *Rows:
		opts = v.options
	case Rows:
		opts = v.options
	case optionalContainer:
		opts = v.getOptions()
	}
	if opts == nil {
		return &dbOptions{}
	}
	return opts
}
// mapperFor returns the Mapper to use for i. That is the Mapper field of a DB,
// Tx or Conn (value or pointer) when it is set. Otherwise it is the package
// default built from NameMapper, so callers never get a nil Mapper.
func mapperFor(i interface{}) *reflectx.Mapper {
	var m *reflectx.Mapper
	switch v := i.(type) {
	case DB:
		m = v.Mapper
	case *DB:
		m = v.Mapper
	case Tx:
		m = v.Mapper
	case *Tx:
		m = v.Mapper
	case Conn:
		m = v.Mapper
	case *Conn:
		m = v.Mapper
	}
	if m == nil {
		return mapper()
	}
	return m
}
// _scannerInterface is the reflect.Type of sql.Scanner. isScannable uses it to
// detect destination types that scan themselves.
var _scannerInterface = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
// _valuerInterface is the reflect.Type of driver.Valuer. NOTE(review): it is not
// referenced in this part of the file, hence the lint directive below.
//lint:ignore U1000 ignoring this for now
var _valuerInterface = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
// Row is a reimplementation of sql.Row in order to gain access to the underlying
// sql.Rows.Columns() data, necessary for StructScan.
//
// A Row holds either the error from running its query (err) or the open result
// set (rows). Scan, Columns and ColumnTypes report err first when it is set.
type Row struct {
// err is the error returned when the query was run; when set, rows is unused.
err error
// options is inherited from the DB, Tx or statement that produced this Row.
options *dbOptions
rows *sql.Rows
// Mapper is inherited from the DB, Tx or statement that produced this Row.
Mapper *reflectx.Mapper
}
// Scan copies the single row of the result into dest, like sql.Row.Scan, but
// surfaces errors from the underlying rows instead of discarding them.
// It returns sql.ErrNoRows for an empty result and ErrMultiRows when a second
// row exists. Unless the query itself failed, the underlying rows are closed
// before Scan returns.
func (r *Row) Scan(dest ...interface{}) error {
	if r.err != nil {
		return r.err
	}
	// Guarantees closing on every early return. The explicit Close at the end
	// exists only to report its error on the success path.
	defer r.rows.Close()

	// A *sql.RawBytes destination would reference driver memory that is only
	// valid until the next Next/Close. The rows are closed before Scan
	// returns, so such destinations are rejected up front.
	for _, d := range dest {
		if _, isRaw := d.(*sql.RawBytes); isRaw {
			return errors.New("sql: RawBytes isn't allowed on Row.Scan")
		}
	}

	if !r.rows.Next() {
		if err := r.rows.Err(); err != nil {
			return err
		}
		return sql.ErrNoRows
	}
	if err := r.rows.Scan(dest...); err != nil {
		return err
	}
	if r.rows.Next() {
		return ErrMultiRows
	}
	if err := r.rows.Err(); err != nil {
		return err
	}
	// Finish the query explicitly so any error it reports reaches the caller.
	return r.rows.Close()
}
// Columns returns the column names of the underlying result. If running the
// query failed, it returns that deferred error (the one Scan would report)
// together with an empty, non-nil slice.
func (r *Row) Columns() ([]string, error) {
	if r.err == nil {
		return r.rows.Columns()
	}
	return []string{}, r.err
}
// ColumnTypes returns the column type information of the underlying result.
// If running the query failed, it returns that deferred error together with
// an empty, non-nil slice.
func (r *Row) ColumnTypes() ([]*sql.ColumnType, error) {
	if r.err == nil {
		return r.rows.ColumnTypes()
	}
	return []*sql.ColumnType{}, r.err
}
// Err returns the error recorded when the Row's query was run, if any. Errors
// that occur later, during Scan, are returned by Scan itself and are not
// reflected here.
func (r *Row) Err() error {
	deferred := r.err
	return deferred
}
// Queryable includes all methods shared by sqlx.DB and sqlx.Tx, allowing
// either type to be used interchangeably.
type Queryable interface {
Ext
ExecerContext
PreparerContext
QueryerContext
Preparer
GetContext(context.Context, interface{}, string, ...interface{}) error
SelectContext(context.Context, interface{}, string, ...interface{}) error
Get(interface{}, string, ...interface{}) error
MustExecContext(context.Context, string, ...interface{}) sql.Result
PreparexContext(context.Context, string) (*Stmt, error)
// NOTE: QueryRowContext and QueryRow return database/sql's *sql.Row (from the
// embedded types), not *sqlx.Row; use QueryRowx for the sqlx variant.
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
Select(interface{}, string, ...interface{}) error
QueryRow(string, ...interface{}) *sql.Row
PrepareNamedContext(context.Context, string) (*NamedStmt, error)
PrepareNamed(string) (*NamedStmt, error)
Preparex(string) (*Stmt, error)
NamedExec(string, interface{}) (sql.Result, error)
NamedExecContext(context.Context, string, interface{}) (sql.Result, error)
MustExec(string, ...interface{}) sql.Result
NamedQuery(string, interface{}) (*Rows, error)
}
// Compile-time checks that both *DB and *Tx satisfy Queryable.
var _ Queryable = (*DB)(nil)
var _ Queryable = (*Tx)(nil)
// dbOptions holds behaviour switches shared by DB, Tx, Conn, statements, Row
// and Rows.
type dbOptions struct {
	// unsafe makes scans tolerate result columns with no destination field.
	unsafe bool
}

// allowMissingFields reports whether scans may ignore result columns that have
// no matching destination field. A nil *dbOptions behaves like the zero value
// instead of panicking. A nil value arises, for example, from a DB or Rows
// built as a struct literal rather than via NewDb/Open.
func (o *dbOptions) allowMissingFields() bool {
	return o != nil && o.unsafe
}
// WithUnsafe returns an option that turns on unsafe mode. In unsafe mode sqlx
// does its best to continue despite scan issues such as missing fields.
func WithUnsafe() func(*dbOptions) {
	enable := func(o *dbOptions) { o.unsafe = true }
	return enable
}
// WithSetUnsafe returns an option that sets unsafe mode to v. In unsafe mode
// sqlx does its best to continue despite scan issues such as missing fields.
func WithSetUnsafe(v bool) func(*dbOptions) {
	set := func(o *dbOptions) {
		o.unsafe = v
	}
	return set
}
// DB is a wrapper around sql.DB which keeps track of the driverName upon Open,
// used mostly to automatically bind named queries using the right bindvars.
type DB struct {
*sql.DB
// driverName selects the bindvar style used by Rebind and BindNamed.
driverName string
// Mapper maps column names to struct fields; MapperFunc replaces it.
Mapper *reflectx.Mapper
// options holds behaviour switches such as unsafe mode; shared with the
// Tx, Rows and Row values this DB creates, and copied by Unsafe.
options *dbOptions
}
// NewDb wraps a pre-existing *sql.DB in an sqlx DB. driverName must name the
// driver the *sql.DB uses; it selects the bindvar style for named queries and
// Rebind.
//
// Functional options configure the wrapper and are applied in the order given.
// For example:
//
//	db := sqlx.NewDb(existingDB, "mysql", sqlx.WithUnsafe())
//
// enables unsafe mode, in which sqlx continues despite scan issues such as
// missing fields. WithSetUnsafe sets the mode explicitly:
//
//	db := sqlx.NewDb(existingDB, "mysql", sqlx.WithSetUnsafe(true))
//
//lint:ignore ST1003 changing this would break the package interface.
func NewDb(db *sql.DB, driverName string, args ...func(*dbOptions)) *DB {
	var opts dbOptions
	for _, apply := range args {
		apply(&opts)
	}
	return &DB{DB: db, driverName: driverName, Mapper: mapper(), options: &opts}
}
// DriverName reports the driver name this DB was opened or wrapped with.
func (db *DB) DriverName() string {
	name := db.driverName
	return name
}
// Open calls sql.Open and wraps the result in an *sqlx.DB configured with the
// given options (applied in order).
func Open(driverName, dataSourceName string, args ...func(*dbOptions)) (*DB, error) {
	opts := new(dbOptions)
	for _, apply := range args {
		apply(opts)
	}
	sqlDB, err := sql.Open(driverName, dataSourceName)
	if err != nil {
		return nil, err
	}
	return &DB{DB: sqlDB, driverName: driverName, Mapper: mapper(), options: opts}, nil
}
// MustOpen is like Open but panics if Open returns an error.
func MustOpen(driverName, dataSourceName string, args ...func(*dbOptions)) *DB {
	db, err := Open(driverName, dataSourceName, args...)
	if err == nil {
		return db
	}
	panic(err)
}
// MapperFunc replaces db's Mapper with a new one. The new Mapper reads the
// "db" struct tag and maps untagged field names through mf.
func (db *DB) MapperFunc(mf func(string) string) {
	m := reflectx.NewMapperFunc("db", mf)
	db.Mapper = m
}
// Rebind rewrites a query written with QUESTION bindvars into the bindvar
// style of db's driver.
func (db *DB) Rebind(query string) string {
	bindType := BindType(db.driverName)
	return Rebind(bindType, query)
}
// Unsafe returns a copy of db that silently succeeds to scan when result
// columns have no field in the destination struct. sqlx.Stmt and sqlx.Tx
// values created from the copy inherit this behaviour. db itself is left
// unchanged. A DB whose options were never set (e.g. a struct literal) is
// treated as having default options instead of panicking.
func (db *DB) Unsafe() *DB {
	var opts dbOptions
	if db.options != nil {
		opts = *db.options
	}
	opts.unsafe = true
	return &DB{DB: db.DB, driverName: db.driverName, Mapper: db.Mapper, options: &opts}
}
// BindNamed binds the named parameters in query from arg and returns the
// query in db's bindvar style together with its positional arguments.
func (db *DB) BindNamed(query string, arg interface{}) (string, []interface{}, error) {
	bindType := BindType(db.driverName)
	return bindNamedMapper(bindType, query, arg, db.Mapper)
}
// NamedQuery runs a query against this DB, filling named placeholders from
// the fields of arg.
func (db *DB) NamedQuery(query string, arg interface{}) (*Rows, error) {
	rows, err := NamedQuery(db, query, arg)
	return rows, err
}
// NamedExec executes a statement against this DB, filling named placeholders
// from the fields of arg.
func (db *DB) NamedExec(query string, arg interface{}) (sql.Result, error) {
	res, err := NamedExec(db, query, arg)
	return res, err
}
// Select runs query against this DB with the given positional args and scans
// every resulting row into dest.
func (db *DB) Select(dest interface{}, query string, args ...interface{}) error {
	err := Select(db, dest, query, args...)
	return err
}
// Get runs query against this DB with the given positional args and scans the
// single resulting row into dest. An error is returned if the result set is
// empty or contains more than one row.
func (db *DB) Get(dest interface{}, query string, args ...interface{}) error {
	err := Get(db, dest, query, args...)
	return err
}
// MustBegin starts a transaction like Beginx and panics if that fails.
func (db *DB) MustBegin() *Tx {
	tx, err := db.Beginx()
	if err == nil {
		return tx
	}
	panic(err)
}
// Beginx begins a transaction and returns it as an *sqlx.Tx. The Tx shares
// db's driver name, options and Mapper.
func (db *DB) Beginx() (*Tx, error) {
	sqlTx, err := db.DB.Begin()
	if err != nil {
		return nil, err
	}
	tx := &Tx{Tx: sqlTx, driverName: db.driverName, options: db.options, Mapper: db.Mapper}
	return tx, nil
}
// Queryx runs query with the given positional args and returns the result as
// an *sqlx.Rows that shares db's options and Mapper.
func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) {
	sqlRows, err := db.DB.Query(query, args...)
	if err != nil {
		return nil, err
	}
	return &Rows{Rows: sqlRows, options: db.options, Mapper: db.Mapper}, nil
}
// QueryRowx runs query with the given positional args and returns an
// *sqlx.Row. Any query error is deferred until the Row is used.
func (db *DB) QueryRowx(query string, args ...interface{}) *Row {
	row := &Row{options: db.options, Mapper: db.Mapper}
	row.rows, row.err = db.DB.Query(query, args...)
	return row
}
// MustExec executes query on this database with the given positional args and
// panics on error.
func (db *DB) MustExec(query string, args ...interface{}) sql.Result {
	res := MustExec(db, query, args...)
	return res
}
// Preparex prepares query on this DB and returns it as an *sqlx.Stmt.
func (db *DB) Preparex(query string) (*Stmt, error) {
	stmt, err := preparexStmt(db, query)
	return stmt, err
}
// PrepareNamed prepares a query with named parameters on this DB and returns
// it as an *sqlx.NamedStmt.
func (db *DB) PrepareNamed(query string) (*NamedStmt, error) {
	stmt, err := PrepareNamed[any](db, query)
	return stmt, err
}
// Conn is a wrapper around sql.Conn with extra functionality
type Conn struct {
*sql.Conn
// NOTE(review): presumably selects the bindvar style as DB.driverName does;
// the methods using it are outside this section.
driverName string
// options is what getOptions returns for a Conn.
options *dbOptions
Mapper *reflectx.Mapper
}
// Tx is an sqlx wrapper around sql.Tx with extra functionality
type Tx struct {
*sql.Tx
// driverName selects the bindvar style used by Rebind and BindNamed.
driverName string
// options is passed on to the Rows, Row and statements this Tx creates;
// Unsafe copies it.
options *dbOptions
// Mapper maps column names to struct fields for results of this Tx.
Mapper *reflectx.Mapper
}
// DriverName reports the driver name inherited from the DB that began this
// transaction.
func (tx *Tx) DriverName() string {
	name := tx.driverName
	return name
}
// Rebind rewrites a query written with QUESTION bindvars into the bindvar
// style of the transaction's driver.
func (tx *Tx) Rebind(query string) string {
	bindType := BindType(tx.driverName)
	return Rebind(bindType, query)
}
// Unsafe returns a copy of tx that silently succeeds to scan when result
// columns have no field in the destination struct. tx itself is left
// unchanged. A Tx whose options were never set (e.g. a struct literal) is
// treated as having default options instead of panicking.
func (tx *Tx) Unsafe() *Tx {
	var opts dbOptions
	if tx.options != nil {
		opts = *tx.options
	}
	opts.unsafe = true
	return &Tx{Tx: tx.Tx, driverName: tx.driverName, options: &opts, Mapper: tx.Mapper}
}
// BindNamed binds the named parameters in query from arg and returns the
// query in the transaction's bindvar style together with its positional
// arguments.
func (tx *Tx) BindNamed(query string, arg interface{}) (string, []interface{}, error) {
	bindType := BindType(tx.driverName)
	return bindNamedMapper(bindType, query, arg, tx.Mapper)
}
// NamedQuery runs a query inside the transaction, filling named placeholders
// from the fields of arg.
func (tx *Tx) NamedQuery(query string, arg interface{}) (*Rows, error) {
	rows, err := NamedQuery(tx, query, arg)
	return rows, err
}
// NamedExec executes a statement inside the transaction, filling named
// placeholders from the fields of arg.
func (tx *Tx) NamedExec(query string, arg interface{}) (sql.Result, error) {
	res, err := NamedExec(tx, query, arg)
	return res, err
}
// Select runs query inside the transaction with the given positional args and
// scans every resulting row into dest.
func (tx *Tx) Select(dest interface{}, query string, args ...interface{}) error {
	err := Select(tx, dest, query, args...)
	return err
}
// Queryx runs query inside the transaction with the given positional args and
// returns an *sqlx.Rows sharing the transaction's options and Mapper.
func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) {
	sqlRows, err := tx.Tx.Query(query, args...)
	if err != nil {
		return nil, err
	}
	return &Rows{Rows: sqlRows, options: tx.options, Mapper: tx.Mapper}, nil
}
// QueryRowx runs query inside the transaction with the given positional args
// and returns an *sqlx.Row. Any query error is deferred until the Row is used.
func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row {
	row := &Row{options: tx.options, Mapper: tx.Mapper}
	row.rows, row.err = tx.Tx.Query(query, args...)
	return row
}
// Get runs query inside the transaction with the given positional args and
// scans the single resulting row into dest. An error is returned if the
// result set is empty or contains more than one row.
func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error {
	err := Get(tx, dest, query, args...)
	return err
}
// MustExec executes query inside the transaction with the given positional
// args and panics on error.
func (tx *Tx) MustExec(query string, args ...interface{}) sql.Result {
	res := MustExec(tx, query, args...)
	return res
}
// Preparex prepares query inside the transaction and returns it as an
// *sqlx.Stmt.
func (tx *Tx) Preparex(query string) (*Stmt, error) {
	stmt, err := preparexStmt(tx, query)
	return stmt, err
}
// Stmtx returns a version of stmt that runs inside tx and carries tx's Mapper
// and options. stmt may be an *sql.Stmt, an *sqlx.Stmt or an sqlx.Stmt value.
// Any other type, including nil, causes a panic whose message names the
// offending type.
func (tx *Tx) Stmtx(stmt interface{}) *GenericStmt[any] {
	var s *sql.Stmt
	switch v := stmt.(type) {
	case Stmt:
		s = v.Stmt
	case *Stmt:
		s = v.Stmt
	case *sql.Stmt:
		s = v
	default:
		// %T prints the dynamic type. Unlike reflect.ValueOf(stmt).Type(), it
		// does not itself panic when stmt is nil (it prints "<nil>").
		panic(fmt.Sprintf("non-statement type %T passed to Stmtx", stmt))
	}
	return &GenericStmt[any]{Stmt: tx.Stmt(s), Mapper: tx.Mapper, options: tx.options}
}
// NamedStmt returns a copy of stmt whose underlying statement runs inside tx.
// The query string and parameter list are carried over unchanged.
func (tx *Tx) NamedStmt(stmt *NamedStmt) *NamedStmt {
	txStmt := tx.Stmtx(stmt.Stmt)
	return &NamedStmt{QueryString: stmt.QueryString, Params: stmt.Params, Stmt: txStmt}
}
// PrepareNamed prepares a query with named parameters inside the transaction
// and returns it as an *sqlx.NamedStmt.
func (tx *Tx) PrepareNamed(query string) (*NamedStmt, error) {
	stmt, err := PrepareNamed[any](tx, query)
	return stmt, err
}
// GenericStmt is an sqlx wrapper around sql.Stmt with extra functionality.
// The type parameter T is the row type produced by One, All and List.
type GenericStmt[T any] struct {
*sql.Stmt
// options is passed on to the Rows and Row values this statement creates.
options *dbOptions
// Mapper maps column names to struct fields for this statement's results.
Mapper *reflectx.Mapper
}
// Stmt is GenericStmt with T = any, for use where no specific row type is
// needed.
type Stmt = GenericStmt[any]
// Unsafe returns a copy of s that silently succeeds to scan when result
// columns have no field in the destination struct. s itself is left
// unchanged. A statement whose options were never set is treated as having
// default options instead of panicking.
func (s *GenericStmt[T]) Unsafe() *GenericStmt[T] {
	var opts dbOptions
	if s.options != nil {
		opts = *s.options
	}
	opts.unsafe = true
	return &GenericStmt[T]{Stmt: s.Stmt, options: &opts, Mapper: s.Mapper}
}
// Select executes the prepared statement with the given positional args and
// scans every resulting row into dest.
func (s *GenericStmt[T]) Select(dest interface{}, args ...interface{}) error {
	q := &qStmt[T]{Stmt: s}
	return Select(q, dest, "", args...)
}
// Get executes the prepared statement with the given positional args and
// scans the single resulting row into dest. An error is returned if the
// result set is empty or contains more than one row.
func (s *GenericStmt[T]) Get(dest interface{}, args ...interface{}) error {
	q := &qStmt[T]{Stmt: s}
	return Get(q, dest, "", args...)
}
// MustExec executes the statement with the given positional args and panics
// on error. The query portion of the panic message is blank, because Stmt
// does not expose its query text.
func (s *GenericStmt[T]) MustExec(args ...interface{}) sql.Result {
	q := &qStmt[T]{Stmt: s}
	return MustExec(q, "", args...)
}
// QueryRowx executes the statement with the given positional args and returns
// an *sqlx.Row. Any query error is deferred until the Row is used.
func (s *GenericStmt[T]) QueryRowx(args ...interface{}) *Row {
	return (&qStmt[T]{Stmt: s}).QueryRowx("", args...)
}
// Queryx executes the statement with the given positional args and returns
// the result as an *sqlx.Rows.
func (s *GenericStmt[T]) Queryx(args ...interface{}) (*Rows, error) {
	return (&qStmt[T]{Stmt: s}).Queryx("", args...)
}
// One executes the statement with the given positional args and returns the
// single resulting row as a T. An error is returned if the result set is
// empty or contains more than one row.
func (s *GenericStmt[T]) One(args ...interface{}) (T, error) {
	var row T
	if err := Get(&qStmt[T]{Stmt: s}, &row, "", args...); err != nil {
		return row, err
	}
	return row, nil
}
// All returns an iterator over the rows produced by executing the statement
// with args. Each row is scanned into a T with StructScan.
//
// The query runs when the iterator is ranged over, and the rows are closed
// when iteration stops. Errors are delivered through the iterator's second
// value rather than by panicking:
//   - a failed query yields a single (zero T, err) pair;
//   - a scan failure yields the partially scanned value with its error;
//   - an error reported by the rows after the last row (rows.Err) is yielded
//     last.
func (s *GenericStmt[T]) All(args ...interface{}) iter.Seq2[T, error] {
	return func(yield func(T, error) bool) {
		rows, err := s.Queryx(args...)
		if err != nil {
			var zero T
			yield(zero, err)
			return
		}
		defer func() { _ = rows.Close() }()

		for rows.Next() {
			var dest T
			err := rows.StructScan(&dest)
			if !yield(dest, err) {
				return
			}
		}
		if err := rows.Err(); err != nil {
			var zero T
			yield(zero, err)
		}
	}
}
// List executes the statement with the given positional args and returns
// every resulting row as a slice of T.
func (s *GenericStmt[T]) List(args ...interface{}) ([]T, error) {
	var out []T
	if err := s.Select(&out, args...); err != nil {
		return out, err
	}
	return out, nil
}
// Prepare returns a transaction-specific prepared statement from
// an existing statement.
//
// When ndb is a *Tx, the returned statement operates within the transaction
// and will be closed when the transaction has been committed or rolled back.
// For any other Queryable, s itself is returned unchanged.
func (s *GenericStmt[T]) Prepare(ndb Queryable) *GenericStmt[T] {
	if tx, isTx := ndb.(*Tx); isTx {
		return &GenericStmt[T]{Stmt: tx.Stmt(s.Stmt), options: s.options, Mapper: s.Mapper}
	}
	// Outside a transaction the statement can be used as-is.
	return s
}
// getOptions exposes the statement's options through optionalContainer,
// because the getOptions type switch cannot name a generic type directly.
func (s *GenericStmt[T]) getOptions() *dbOptions {
	return s.options
}
// qStmt is an unexposed wrapper which lets you use a Stmt as a Queryer & Execer by
// implementing those interfaces and ignoring the `query` argument.
// Rows and Row values it creates take their options and Mapper from the
// wrapped statement.
type qStmt[T any] struct{ Stmt *GenericStmt[T] }
// getOptions returns the wrapped statement's options, letting getOptions
// resolve a *qStmt through optionalContainer.
func (q *qStmt[T]) getOptions() *dbOptions {
	stmt := q.Stmt
	return stmt.options
}
// Query executes the wrapped statement with args; query is ignored.
func (q *qStmt[T]) Query(query string, args ...interface{}) (*sql.Rows, error) {
	rows, err := q.Stmt.Query(args...)
	return rows, err
}
// Queryx executes the wrapped statement with args and returns an *sqlx.Rows;
// query is ignored.
func (q *qStmt[T]) Queryx(query string, args ...interface{}) (*Rows, error) {
	sqlRows, err := q.Stmt.Query(args...)
	if err != nil {
		return nil, err
	}
	return &Rows{Rows: sqlRows, options: q.Stmt.options, Mapper: q.Stmt.Mapper}, nil
}
// QueryRowx executes the wrapped statement with args and returns an
// *sqlx.Row carrying any query error; query is ignored.
func (q *qStmt[T]) QueryRowx(query string, args ...interface{}) *Row {
	stmt := q.Stmt
	rows, err := stmt.Query(args...)
	return &Row{rows: rows, err: err, options: stmt.options, Mapper: stmt.Mapper}
}
// Exec executes the wrapped statement with args; query is ignored.
func (q *qStmt[T]) Exec(query string, args ...interface{}) (sql.Result, error) {
	res, err := q.Stmt.Exec(args...)
	return res, err
}
// Rows is a wrapper around sql.Rows which caches costly reflect operations
// during a looped StructScan
type Rows struct {
*sql.Rows
// options is inherited from the DB, Tx or statement that produced these Rows.
options *dbOptions
// Mapper maps column names to struct fields for StructScan.
Mapper *reflectx.Mapper
// The fields below cache per-result-set state for StructScan across iterations.
// started reports whether fields/values are computed; NextResultSet clears it.
started bool
// fields holds, per column, the struct traversal computed by StructScan.
fields [][]int
// values is the reusable slice of scan destinations passed to Scan.
values []interface{}
}
// SliceScan scans the current row into a freshly allocated slice, one element
// per column.
func (r *Rows) SliceScan() ([]interface{}, error) {
	vals, err := SliceScan(r)
	return vals, err
}
// MapScan scans the current row into dest, keyed by column name.
func (r *Rows) MapScan(dest map[string]interface{}) error {
	err := MapScan(r, dest)
	return err
}
// StructScan is like sql.Rows.Scan, but scans a single Row into a single Struct.
// Use this and iterate over Rows manually when the memory load of Select() might be
// prohibitive. *Rows.StructScan caches the reflect work of matching up column
// positions to fields to avoid that overhead per scan, which means it is not safe
// to run StructScan on the same Rows instance with different struct types.
//
// dest must be a non-nil pointer. Unless unsafe mode is on, a result column
// with no matching destination field is an error. Rows built without a Mapper
// use the package default Mapper.
func (r *Rows) StructScan(dest interface{}) error {
	v := reflect.ValueOf(dest)
	if v.Kind() != reflect.Ptr {
		return errors.New("must pass a pointer, not a value, to StructScan destination")
	}
	// A typed nil pointer would otherwise panic in v.Elem().Type() below.
	if v.IsNil() {
		return errors.New("nil pointer passed to StructScan destination")
	}
	v = v.Elem()

	if !r.started {
		columns, err := r.Columns()
		if err != nil {
			return err
		}
		m := r.Mapper
		if m == nil {
			m = mapper()
		}
		r.fields = m.TraversalsByName(v.Type(), columns)
		if !getOptions(r).allowMissingFields() {
			// In safe mode, a column that maps to no field is an error.
			if f, err := missingFields(r.fields); err != nil {
				return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
			}
		}
		r.values = make([]interface{}, len(columns))
		r.started = true
	}

	// Point r.values at the struct's fields, then scan the row into them.
	if err := fieldsByTraversal(v, r.fields, r.values); err != nil {
		return err
	}
	if err := r.Scan(r.values...); err != nil {
		return err
	}
	return r.Err()
}
// NextResultSet advances to the next result set, if there is one. On success
// it invalidates the cached column-to-field mapping, because the new set may
// have a different shape.
func (r *Rows) NextResultSet() bool {
	advanced := r.Rows.NextResultSet()
	if advanced {
		r.started = false
	}
	return advanced
}
// AllRows returns an iter.Seq2 for ranging over rows, StructScanning each row
// into a fresh T. The error is non-nil when scanning a row fails. If iteration
// stops because of an error (rows.Err() after Next returns false), one final
// pair with the zero T and that error is yielded, so it is not silently lost.
// Calling code should check the error in the loop. rows is closed when
// iteration finishes or the loop is exited early.
func AllRows[T any](rows *Rows) iter.Seq2[T, error] {
	return func(yield func(T, error) bool) {
		defer func() {
			// Close errors are not actionable here; iteration errors are
			// surfaced through rows.Err() below.
			_ = rows.Close()
		}()
		for rows.Next() {
			var dest T
			err := rows.StructScan(&dest)
			if !yield(dest, err) {
				return
			}
		}
		// Next returns false both at the end of the data and on failure;
		// only Err tells the two apart.
		if err := rows.Err(); err != nil {
			var zero T
			yield(zero, err)
		}
	}
}
// Connect opens a database with Open and verifies it with a ping. If the ping
// fails, the database is closed and nil is returned along with the ping error.
func Connect(driverName, dataSourceName string, args ...func(*dbOptions)) (*DB, error) {
	db, err := Open(driverName, dataSourceName, args...)
	if err != nil {
		return nil, err
	}
	if pingErr := db.Ping(); pingErr != nil {
		db.Close()
		return nil, pingErr
	}
	return db, nil
}
// MustConnect behaves like Connect but panics if connecting or pinging fails.
func MustConnect(driverName, dataSourceName string, args ...func(*dbOptions)) *DB {
	db, err := Connect(driverName, dataSourceName, args...)
	if err == nil {
		return db
	}
	panic(err)
}
// Preparex prepares query on p and wraps the resulting statement in a
// GenericStmt[T] that inherits p's options and mapper.
func Preparex[T any](p Preparer, query string) (*GenericStmt[T], error) {
	stmt, err := p.Prepare(query)
	if err != nil {
		return nil, err
	}
	wrapped := &GenericStmt[T]{Stmt: stmt, options: getOptions(p), Mapper: mapperFor(p)}
	return wrapped, nil
}
// preparexStmt prepares query on p and returns a non-generic *Stmt. It is a
// workaround for generic type aliases not being available before Go 1.24.
func preparexStmt(p Preparer, query string) (*Stmt, error) {
	stmt, err := p.Prepare(query)
	if err != nil {
		return nil, err
	}
	wrapped := &Stmt{Stmt: stmt, options: getOptions(p), Mapper: mapperFor(p)}
	return wrapped, nil
}
// Select runs query through q and scans every row into dest, which must be a
// pointer to a slice. Scannable element types require a single-column result;
// other element types are filled with StructScan. The rows are always closed.
// Any placeholder parameters are replaced with supplied args.
func Select(q Queryer, dest interface{}, query string, args ...interface{}) error {
	rows, err := q.Queryx(query, args...)
	if err != nil {
		return err
	}
	// Close even if scanAll bails out part-way through the result set.
	defer rows.Close()
	const structOnly = false
	return scanAll(rows, dest, structOnly)
}
// Get runs query through q as a single-row query and scans the row into dest.
// A scannable dest requires a single-column result; otherwise StructScan is
// used. Like row.Scan, it returns sql.ErrNoRows when there is no row.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty or contains more than one row.
func Get(q Queryer, dest interface{}, query string, args ...interface{}) error {
	return q.QueryRowx(query, args...).scanAny(dest, false)
}
// One runs query through q as a single-row query and returns the row scanned
// into a new T. A scannable T requires a single-column result; otherwise
// StructScan is used. Like row.Scan, One returns sql.ErrNoRows when there is
// no row. Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty or contains more than one row.
func One[T any](q Queryer, query string, args ...interface{}) (T, error) {
	var result T
	err := q.QueryRowx(query, args...).scanAny(&result, false)
	return result, err
}
// List runs query through q and returns one T per result row, scanned with
// Select. On error, the slice built so far is returned together with the error.
func List[T any](q Queryer, query string, args ...interface{}) ([]T, error) {
	var out []T
	if err := Select(q, &out, query, args...); err != nil {
		return out, err
	}
	return out, nil
}
// LoadFile reads the file at path and executes its entire contents with one
// call to e.Exec. It returns a nil *sql.Result if the path cannot be resolved
// or the file cannot be read. The whole file is held in memory, so this suits
// schema setup or index creation rather than large data dumps.
//
// FIXME: this does not really work with multi-statement files for mattn/go-sqlite3
// or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting
// this by requiring something with DriverName() and then attempting to split the
// queries will be difficult to get right, and its current driver-specific behavior
// is deemed at least not complex in its incorrectness.
func LoadFile(e Execer, path string) (*sql.Result, error) {
	abs, err := filepath.Abs(path)
	if err != nil {
		return nil, err
	}
	contents, err := os.ReadFile(abs)
	if err != nil {
		return nil, err
	}
	result, execErr := e.Exec(string(contents))
	return &result, execErr
}
// MustExec runs query through e and panics if execution fails.
// Any placeholder parameters are replaced with supplied args.
func MustExec(e Execer, query string, args ...interface{}) sql.Result {
	res, err := e.Exec(query, args...)
	if err == nil {
		return res
	}
	panic(err)
}
// SliceScan scans this Row into a fresh []interface{}, using the
// package-level SliceScan.
func (r *Row) SliceScan() ([]interface{}, error) {
	row, err := SliceScan(r)
	return row, err
}
// MapScan scans this Row into dest, keyed by column name, using the
// package-level MapScan.
func (r *Row) MapScan(dest map[string]interface{}) error {
	err := MapScan(r, dest)
	return err
}
// scanAny scans the row held by r into dest, closing r.rows on every path once
// they are known to be non-nil.
//
// dest must be a non-nil pointer. If the dereferenced type is scannable (per
// isScannable), the result may have at most one column and is scanned
// directly; otherwise columns are mapped onto struct fields using r.Mapper.
// When structOnly is true, scannable destinations are rejected via
// structOnlyError. A query error recorded on r is returned first, and a Row
// with no rows reports (and records) sql.ErrNoRows.
func (r *Row) scanAny(dest interface{}, structOnly bool) error {
	// A query error captured when the Row was built takes precedence.
	if r.err != nil {
		return r.err
	}
	if r.rows == nil {
		r.err = sql.ErrNoRows
		return r.err
	}
	// Release the underlying rows on every return path below.
	defer r.rows.Close()
	v := reflect.ValueOf(dest)
	if v.Kind() != reflect.Ptr {
		return errors.New("must pass a pointer, not a value, to StructScan destination")
	}
	if v.IsNil() {
		return errors.New("nil pointer passed to StructScan destination")
	}
	base := reflectx.Deref(v.Type())
	scannable := isScannable(base)
	if structOnly && scannable {
		return structOnlyError(base)
	}
	columns, err := r.Columns()
	if err != nil {
		return err
	}
	// A scannable dest can only absorb a single column.
	if scannable && len(columns) > 1 {
		return fmt.Errorf("scannable dest type %s with >1 columns (%d) in result", base.Kind(), len(columns))
	}
	if scannable {
		return r.Scan(dest)
	}
	// Struct path: resolve each column to a field traversal.
	m := r.Mapper
	fields := m.TraversalsByName(v.Type(), columns)
	if !getOptions(r).allowMissingFields() {
		// if we are not unsafe and are missing fields, return an error
		if f, err := missingFields(fields); err != nil {
			return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
		}
	}
	values := make([]interface{}, len(columns))
	err = fieldsByTraversal(v, fields, values)
	if err != nil {
		return err
	}
	// scan into the struct field pointers and append to our results
	return r.Scan(values...)
}
// StructScan scans this Row into dest, which must be a pointer to a
// non-scannable struct.
func (r *Row) StructScan(dest interface{}) error {
	const structOnly = true
	return r.scanAny(dest, structOnly)
}
// SliceScan scans the current row of r and returns one value per column, in
// column order, with values like those MapScan produces. It is meant for
// results whose column count is unknown ahead of time. When the columns are
// known, pass your own []interface{} to Scan to avoid a per-row allocation.
func SliceScan(r ColScanner) ([]interface{}, error) {
	// r.started is irrelevant here: no reflection is involved.
	cols, err := r.Columns()
	if err != nil {
		return []interface{}{}, err
	}
	row := make([]interface{}, len(cols))
	for idx := range row {
		row[idx] = new(interface{})
	}
	if err = r.Scan(row...); err != nil {
		return row, err
	}
	// Replace each *interface{} holder with the value it received.
	for idx := range cols {
		holder := row[idx].(*interface{})
		row[idx] = *holder
	}
	return row, r.Err()
}
// MapScan scans the current row of r into dest, keyed by column name. It is
// intended for SQL whose shape you do not control, such as a tool that runs
// user-supplied queries, and should not be a primary interface. dest is
// modified in place, so reuse it with care. If a column name appears more
// than once, later columns overwrite earlier ones.
func MapScan(r ColScanner, dest map[string]interface{}) error {
	// r.started is irrelevant here: no reflection is involved.
	cols, err := r.Columns()
	if err != nil {
		return err
	}
	holders := make([]interface{}, len(cols))
	for idx := range holders {
		holders[idx] = new(interface{})
	}
	if err = r.Scan(holders...); err != nil {
		return err
	}
	for idx, name := range cols {
		holder := holders[idx].(*interface{})
		dest[name] = *holder
	}
	return r.Err()
}
// rowsi is the subset of row-iteration methods needed by scanAll and
// StructScan. Both *sql.Rows and *Rows satisfy it (*Rows through its embedded
// *sql.Rows).
type rowsi interface {
	Close() error
	Columns() ([]string, error)
	Err() error
	Next() bool
	Scan(...interface{}) error
}
// structOnlyError explains why t cannot serve as a StructScan destination:
// it is not a struct, it is a struct implementing the scanner interface, or
// it is a struct with nothing to map onto.
func structOnlyError(t reflect.Type) error {
	switch {
	case t.Kind() != reflect.Struct:
		return fmt.Errorf("expected %s but got %s", reflect.Struct, t.Kind())
	case reflect.PointerTo(t).Implements(_scannerInterface):
		return fmt.Errorf("structscan expects a struct dest but the provided struct type %s implements scanner", t.Name())
	default:
		return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name())
	}
}
// scanAll scans all rows into a destination, which must be a slice of any
// type. It resets the slice length to zero before appending each element to
// the slice. If the destination slice type is a Struct, then StructScan will
// be used on each row. If the destination is some other kind of base type,
// then each row must only have one column which can scan into that type. This
// allows you to do something like:
//
//	rows, _ := db.Query("select id from people;")
//	var ids []int
//	scanAll(rows, &ids, false)
//
// and ids will be a list of the id results. I realize that this is a desirable
// interface to expose to users, but for now it will only be exposed via changes
// to `Get` and `Select`. The reason that this has been implemented like this is
// this is the only way to not duplicate reflect work in the new API while
// maintaining backwards compatibility.
//
// Slice elements may be T or *T; isPtr records which, so each row is appended
// as either the new pointer or the value it points to. scanAll does not close
// rows; callers such as Select defer that themselves.
func scanAll(rows rowsi, dest interface{}, structOnly bool) error {
	var v, vp reflect.Value
	value := reflect.ValueOf(dest)
	// json.Unmarshal returns errors for these
	if value.Kind() != reflect.Ptr {
		return errors.New("must pass a pointer, not a value, to StructScan destination")
	}
	if value.IsNil() {
		return errors.New("nil pointer passed to StructScan destination")
	}
	// direct is the slice itself (addressable, via the pointer).
	direct := reflect.Indirect(value)
	slice, err := baseType(value.Type(), reflect.Slice)
	if err != nil {
		return err
	}
	direct.SetLen(0)
	isPtr := slice.Elem().Kind() == reflect.Ptr
	base := reflectx.Deref(slice.Elem())
	scannable := isScannable(base)
	if structOnly && scannable {
		return structOnlyError(base)
	}
	columns, err := rows.Columns()
	if err != nil {
		return err
	}
	// if it's a base type make sure it only has 1 column; if not return an error
	if scannable && len(columns) > 1 {
		return fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(columns))
	}
	if !scannable {
		var values []interface{}
		var m *reflectx.Mapper
		// Prefer the mapper carried by *Rows; plain rowsi falls back to the
		// package default.
		switch rows := rows.(type) {
		case *Rows:
			m = rows.Mapper
		default:
			m = mapper()
		}
		// Column -> field traversals are computed once and reused for every row.
		fields := m.TraversalsByName(base, columns)
		if !getOptions(rows).allowMissingFields() {
			// if we are not unsafe and are missing fields, return an error
			if f, err := missingFields(fields); err != nil {
				return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
			}
		}
		// values is reused across rows; fieldsByTraversal refills every slot.
		values = make([]interface{}, len(columns))
		for rows.Next() {
			// create a new struct type (which returns PtrTo) and indirect it
			vp = reflect.New(base)
			v = reflect.Indirect(vp)
			err = fieldsByTraversal(v, fields, values)
			if err != nil {
				return err
			}
			// scan into the struct field pointers and append to our results
			err = rows.Scan(values...)
			if err != nil {
				return err
			}
			if isPtr {
				direct.Set(reflect.Append(direct, vp))
			} else {
				direct.Set(reflect.Append(direct, v))
			}
		}
	} else {
		// Scannable element: scan the single column straight into a new value.
		for rows.Next() {
			vp = reflect.New(base)
			err = rows.Scan(vp.Interface())
			if err != nil {
				return err
			}
			// append
			if isPtr {
				direct.Set(reflect.Append(direct, vp))
			} else {
				direct.Set(reflect.Append(direct, reflect.Indirect(vp)))
			}
		}
	}
	// Surface any error that ended iteration early.
	return rows.Err()
}
// FIXME: StructScan was the very first bit of API in sqlx, and now unfortunately
// it doesn't really feel like it's named properly. There is an incongruency
// between this and the way that StructScan (which might better be ScanStruct
// anyway) works on a rows object.

// StructScan reads every remaining row from rows (an *sql.Rows or *sqlx.Rows)
// into dest, a pointer to a slice of structs. The whole result set is
// materialized; to avoid that, iterate with Queryx and sqlx.Rows.StructScan.
// *sqlx.Rows contributes its own mapper; anything else uses the default.
func StructScan(rows rowsi, dest interface{}) error {
	const structOnly = true
	return scanAll(rows, dest, structOnly)
}
// reflect helpers

// baseType strips one level of pointer from t (via reflectx.Deref) and
// requires the result to be of the expected kind.
func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
	underlying := reflectx.Deref(t)
	if kind := underlying.Kind(); kind != expected {
		return nil, fmt.Errorf("expected %s but got %s", expected, kind)
	}
	return underlying, nil
}
// fieldsByTraversal fills values with Scan destinations taken from the struct
// v (or the struct it points to), one per traversal:
//   - an empty traversal (unmatched column) gets a throwaway *interface{};
//   - a single-level traversal gets the address of that field directly;
//   - a nested traversal gets an optDest, which resolves the field address
//     only when a non-NULL value arrives.
//
// It is written by hand instead of using FieldsByName to save allocations and
// map lookups when iterating over many rows.
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}) error {
	v = reflect.Indirect(v)
	if v.Kind() != reflect.Struct {
		return errors.New("argument not a struct")
	}
	for i, traversal := range traversals {
		switch len(traversal) {
		case 0:
			values[i] = new(interface{})
		case 1:
			values[i] = reflectx.FieldByIndexes(v, traversal).Addr().Interface()
		default:
			// reflectx.FieldByIndexes initializes pointer fields along the path,
			// so defer it. This keeps optional nested structs (e.g. from LEFT
			// JOINs) nil when all their columns are NULL.
			values[i] = optDest(func() interface{} {
				return reflectx.FieldByIndexes(v, traversal).Addr().Interface()
			})
		}
	}
	return nil
}
// missingFields reports the index of the first empty traversal, i.e. the
// first result column with no matching destination field. err is non-nil
// exactly when such a column exists; otherwise field is 0.
func missingFields(traversals [][]int) (field int, err error) {
	for idx := range traversals {
		if len(traversals[idx]) != 0 {
			continue
		}
		return idx, errors.New("missing field")
	}
	return 0, nil
}
// optDest will only forward the Scan to the nested value if
// the database value is not nil.
// The function is called lazily to obtain the destination pointer, and only
// for non-NULL values. fieldsByTraversal uses this for nested traversals so
// that pointer fields on the path are initialized only when data is present.
type optDest func() interface{}
// Scan implements sql.Scanner. A NULL (nil) src is ignored and the
// destination is never resolved. Any other value is converted into the
// pointer returned by dest.
func (dest optDest) Scan(src interface{}) error {
	if src != nil {
		return convertAssign(dest(), src)
	}
	return nil
}
package sqlx
import (
"context"
"database/sql"
"fmt"
"os"
"path/filepath"
"reflect"
)
// ConnectContext to a database and verify with a ping.
func ConnectContext(ctx context.Context, driverName, dataSourceName string) (*DB, error) {
db, err := Open(driverName, dataSourceName)
if err != nil {
return db, err
}
err = db.PingContext(ctx)
return db, err
}
// QueryerContext is an interface used by GetContext and SelectContext
// (and, through them or directly, by OneContext and ListContext).
type QueryerContext interface {
	// QueryContext runs query and returns plain *sql.Rows.
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	// QueryxContext runs query and returns sqlx *Rows.
	QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error)
	// QueryRowxContext runs query and returns a *Row. The implementations in
	// this package store any query error on the Row rather than returning it.
	QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row
}

// PreparerContext is an interface used by PreparexContext.
type PreparerContext interface {
	// PrepareContext prepares query; ctx covers preparation only.
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}

// ExecerContext is an interface used by MustExecContext and LoadFileContext
type ExecerContext interface {
	// ExecContext executes query without returning rows.
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}

// ExtContext is a union interface which can bind, query, and exec, with Context
// used by NamedQueryContext and NamedExecContext.
type ExtContext interface {
	binder
	QueryerContext
	ExecerContext
}
// SelectContext runs query through q under ctx and scans every row into dest,
// which must be a pointer to a slice. Scannable element types require a
// single-column result; other element types are filled with StructScan.
// The rows are always closed.
// Any placeholder parameters are replaced with supplied args.
func SelectContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error {
	rows, err := q.QueryxContext(ctx, query, args...)
	if err != nil {
		return err
	}
	// Close even if scanAll bails out part-way through the result set.
	defer rows.Close()
	const structOnly = false
	return scanAll(rows, dest, structOnly)
}
// PreparexContext prepares query on p and wraps it in a GenericStmt[T] that
// inherits p's options and mapper.
//
// ctx applies to preparing the statement only, not to later executions.
func PreparexContext[T any](ctx context.Context, p PreparerContext, query string) (*GenericStmt[T], error) {
	stmt, err := p.PrepareContext(ctx, query)
	if err != nil {
		return nil, err
	}
	wrapped := &GenericStmt[T]{Stmt: stmt, options: getOptions(p), Mapper: mapperFor(p)}
	return wrapped, nil
}
// preparexContextStmt prepares query on p and returns a non-generic *Stmt. It
// stands in until generic type aliases (expected in Go 1.24) allow
// PreparexContext[any] to return Stmt directly.
func preparexContextStmt(ctx context.Context, p PreparerContext, query string) (*Stmt, error) {
	stmt, err := p.PrepareContext(ctx, query)
	if err != nil {
		return nil, err
	}
	wrapped := &Stmt{Stmt: stmt, options: getOptions(p), Mapper: mapperFor(p)}
	return wrapped, nil
}
// GetContext runs query through q under ctx as a single-row query and scans
// the row into dest. A scannable dest requires a single-column result;
// otherwise StructScan is used. Like row.Scan, it returns sql.ErrNoRows when
// there is no row. Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func GetContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error {
	return q.QueryRowxContext(ctx, query, args...).scanAny(dest, false)
}
// OneContext runs query through q under ctx as a single-row query and returns
// the row scanned into a new T. A scannable T requires a single-column result;
// otherwise StructScan is used. Like row.Scan, OneContext returns
// sql.ErrNoRows when there is no row. Any placeholder parameters are replaced
// with supplied args. An error is returned if the result set is empty.
func OneContext[T any](ctx context.Context, q QueryerContext, query string, args ...interface{}) (T, error) {
	var result T
	err := q.QueryRowxContext(ctx, query, args...).scanAny(&result, false)
	return result, err
}
// ListContext runs query through q under ctx and returns one T per result
// row, scanned with SelectContext. On error, the slice built so far is
// returned together with the error.
func ListContext[T any](ctx context.Context, q QueryerContext, query string, args ...interface{}) ([]T, error) {
	var out []T
	if err := SelectContext(ctx, q, &out, query, args...); err != nil {
		return out, err
	}
	return out, nil
}
// LoadFileContext reads the file at path and executes its entire contents
// with one call to e.ExecContext. It returns a nil *sql.Result if the path
// cannot be resolved or the file cannot be read. The whole file is held in
// memory, so this suits schema setup or index creation rather than large data
// dumps.
//
// FIXME: this does not really work with multi-statement files for mattn/go-sqlite3
// or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting
// this by requiring something with DriverName() and then attempting to split the
// queries will be difficult to get right, and its current driver-specific behavior
// is deemed at least not complex in its incorrectness.
func LoadFileContext(ctx context.Context, e ExecerContext, path string) (*sql.Result, error) {
	abs, err := filepath.Abs(path)
	if err != nil {
		return nil, err
	}
	contents, err := os.ReadFile(abs)
	if err != nil {
		return nil, err
	}
	result, execErr := e.ExecContext(ctx, string(contents))
	return &result, execErr
}
// MustExecContext runs query through e under ctx and panics if execution fails.
// Any placeholder parameters are replaced with supplied args.
func MustExecContext(ctx context.Context, e ExecerContext, query string, args ...interface{}) sql.Result {
	res, err := e.ExecContext(ctx, query, args...)
	if err == nil {
		return res
	}
	panic(err)
}
// PrepareNamedContext prepares a named-parameter query on this DB and returns
// an sqlx.NamedStmt.
func (db *DB) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) {
	stmt, err := PrepareNamedContext[any](ctx, db, query)
	return stmt, err
}
// NamedQueryContext runs a named-parameter query on this DB.
// Any named placeholder parameters are replaced with fields from arg.
func (db *DB) NamedQueryContext(ctx context.Context, query string, arg interface{}) (*Rows, error) {
	rows, err := NamedQueryContext(ctx, db, query, arg)
	return rows, err
}
// NamedExecContext executes a named-parameter statement on this DB.
// Any named placeholder parameters are replaced with fields from arg.
func (db *DB) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) {
	res, err := NamedExecContext(ctx, db, query, arg)
	return res, err
}
// SelectContext runs query on this DB and scans every row into dest (a
// pointer to a slice).
// Any placeholder parameters are replaced with supplied args.
func (db *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
	err := SelectContext(ctx, db, dest, query, args...)
	return err
}
// GetContext runs query on this DB and scans the single resulting row into
// dest. Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func (db *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
	err := GetContext(ctx, db, dest, query, args...)
	return err
}
// PreparexContext prepares query on this DB and returns an sqlx.Stmt rather
// than a sql.Stmt.
//
// ctx applies to preparing the statement only, not to later executions.
func (db *DB) PreparexContext(ctx context.Context, query string) (*Stmt, error) {
	stmt, err := preparexContextStmt(ctx, db, query)
	return stmt, err
}
// QueryxContext runs query on the underlying database and wraps the result in
// *sqlx.Rows that carry this DB's options and mapper.
// Any placeholder parameters are replaced with supplied args.
func (db *DB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
	sqlRows, err := db.DB.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	wrapped := &Rows{Rows: sqlRows, options: db.options, Mapper: db.Mapper}
	return wrapped, nil
}
// QueryRowxContext runs query on the underlying database and returns an
// *sqlx.Row. Any query error is stored on the Row for its scan methods.
// Any placeholder parameters are replaced with supplied args.
func (db *DB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row {
	rows, err := db.DB.QueryContext(ctx, query, args...)
	row := &Row{rows: rows, err: err, options: db.options, Mapper: db.Mapper}
	return row
}
// MustBeginTx starts a transaction via BeginTxx and panics on error. It
// returns an *sqlx.Tx rather than an *sql.Tx.
//
// ctx governs the transaction until it is committed or rolled back. If ctx is
// canceled, the sql package rolls the transaction back, and Tx.Commit returns
// an error.
func (db *DB) MustBeginTx(ctx context.Context, opts *sql.TxOptions) *Tx {
	tx, err := db.BeginTxx(ctx, opts)
	if err == nil {
		return tx
	}
	panic(err)
}
// MustExecContext executes query on this DB and panics on failure.
// Any placeholder parameters are replaced with supplied args.
func (db *DB) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result {
	res := MustExecContext(ctx, db, query, args...)
	return res
}
// BeginTxx begins a transaction and returns an *sqlx.Tx (not an *sql.Tx) that
// inherits this DB's driver name, options and mapper.
//
// ctx governs the transaction until it is committed or rolled back. If ctx is
// canceled, the sql package rolls the transaction back, and Tx.Commit returns
// an error.
func (db *DB) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
	sqlTx, err := db.DB.BeginTx(ctx, opts)
	if err != nil {
		return nil, err
	}
	wrapped := &Tx{Tx: sqlTx, driverName: db.driverName, options: db.options, Mapper: db.Mapper}
	return wrapped, nil
}
// Connx reserves a single connection from the pool and returns it as an
// *sqlx.Conn that inherits this DB's driver name, options and mapper.
func (db *DB) Connx(ctx context.Context) (*Conn, error) {
	sqlConn, err := db.DB.Conn(ctx)
	if err != nil {
		return nil, err
	}
	wrapped := &Conn{Conn: sqlConn, driverName: db.driverName, options: db.options, Mapper: db.Mapper}
	return wrapped, nil
}
// BeginTxx begins a transaction on this connection and returns an *sqlx.Tx
// (not an *sql.Tx) that inherits the connection's driver name, options and
// mapper.
//
// ctx governs the transaction until it is committed or rolled back. If ctx is
// canceled, the sql package rolls the transaction back, and Tx.Commit returns
// an error.
func (c *Conn) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
	sqlTx, err := c.Conn.BeginTx(ctx, opts)
	if err != nil {
		return nil, err
	}
	wrapped := &Tx{Tx: sqlTx, driverName: c.driverName, options: c.options, Mapper: c.Mapper}
	return wrapped, nil
}
// SelectContext runs query on this connection and scans every row into dest
// (a pointer to a slice).
// Any placeholder parameters are replaced with supplied args.
func (c *Conn) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
	err := SelectContext(ctx, c, dest, query, args...)
	return err
}
// GetContext runs query on this connection and scans the single resulting
// row into dest. Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func (c *Conn) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
	err := GetContext(ctx, c, dest, query, args...)
	return err
}
// PreparexContext prepares query on this connection and returns an sqlx.Stmt
// rather than a sql.Stmt.
//
// ctx applies to preparing the statement only, not to later executions.
func (c *Conn) PreparexContext(ctx context.Context, query string) (*Stmt, error) {
	stmt, err := preparexContextStmt(ctx, c, query)
	return stmt, err
}
// QueryxContext runs query on this connection and wraps the result in
// *sqlx.Rows that carry the connection's options and mapper.
// Any placeholder parameters are replaced with supplied args.
func (c *Conn) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
	sqlRows, err := c.Conn.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	wrapped := &Rows{Rows: sqlRows, options: c.options, Mapper: c.Mapper}
	return wrapped, nil
}
// QueryRowxContext runs query on this connection and returns an *sqlx.Row.
// Any query error is stored on the Row for its scan methods.
// Any placeholder parameters are replaced with supplied args.
func (c *Conn) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row {
	rows, err := c.Conn.QueryContext(ctx, query, args...)
	row := &Row{rows: rows, err: err, options: c.options, Mapper: c.Mapper}
	return row
}
// Rebind rewrites query's `?` placeholders into the bindvar style of this
// connection's driver.
func (c *Conn) Rebind(query string) string {
	bindType := BindType(c.driverName)
	return Rebind(bindType, query)
}
// StmtxContext returns a transaction-specific copy of a prepared statement.
// stmt may be a Stmt, *Stmt or *sql.Stmt; any other type causes a panic.
// The result carries this transaction's mapper and options.
func (tx *Tx) StmtxContext(ctx context.Context, stmt interface{}) *GenericStmt[any] {
	var raw *sql.Stmt
	switch v := stmt.(type) {
	case Stmt:
		raw = v.Stmt
	case *Stmt:
		raw = v.Stmt
	case *sql.Stmt:
		raw = v
	default:
		panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type()))
	}
	bound := tx.StmtContext(ctx, raw)
	return &GenericStmt[any]{Stmt: bound, Mapper: tx.Mapper, options: tx.options}
}
// NamedStmtContext returns a copy of stmt whose underlying prepared statement
// is bound to this transaction. The query string and parameter list are
// carried over from stmt.
func (tx *Tx) NamedStmtContext(ctx context.Context, stmt *NamedStmt) *NamedStmt {
	bound := tx.StmtxContext(ctx, stmt.Stmt)
	return &NamedStmt{
		QueryString: stmt.QueryString,
		Params:      stmt.Params,
		Stmt:        bound,
	}
}
// PreparexContext prepares query within this transaction and returns an
// sqlx.Stmt rather than a sql.Stmt.
//
// ctx applies to preparing the statement only, not to later executions.
func (tx *Tx) PreparexContext(ctx context.Context, query string) (*Stmt, error) {
	stmt, err := preparexContextStmt(ctx, tx, query)
	return stmt, err
}
// PrepareNamedContext prepares a named-parameter query within this
// transaction and returns an sqlx.NamedStmt.
func (tx *Tx) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) {
	stmt, err := PrepareNamedContext[any](ctx, tx, query)
	return stmt, err
}
// MustExecContext executes query within this transaction and panics on failure.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result {
	res := MustExecContext(ctx, tx, query, args...)
	return res
}
// QueryxContext runs query within this transaction and wraps the result in
// *sqlx.Rows that carry the transaction's options and mapper.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
	sqlRows, err := tx.Tx.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	wrapped := &Rows{Rows: sqlRows, options: tx.options, Mapper: tx.Mapper}
	return wrapped, nil
}
// SelectContext runs query within this transaction and scans every row into
// dest (a pointer to a slice).
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
	err := SelectContext(ctx, tx, dest, query, args...)
	return err
}
// GetContext runs query within this transaction and scans the single
// resulting row into dest.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func (tx *Tx) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
	err := GetContext(ctx, tx, dest, query, args...)
	return err
}
// QueryRowxContext runs query within this transaction and returns an
// *sqlx.Row. Any query error is stored on the Row for its scan methods.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row {
	rows, err := tx.Tx.QueryContext(ctx, query, args...)
	row := &Row{rows: rows, err: err, options: tx.options, Mapper: tx.Mapper}
	return row
}
// NamedExecContext executes a named-parameter statement within this transaction.
// Any named placeholder parameters are replaced with fields from arg.
func (tx *Tx) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) {
	res, err := NamedExecContext(ctx, tx, query, arg)
	return res, err
}
// SelectContext executes the prepared statement and scans every row into
// dest (a pointer to a slice).
// Any placeholder parameters are replaced with supplied args.
func (s *GenericStmt[T]) SelectContext(ctx context.Context, dest interface{}, args ...interface{}) error {
	wrapper := &qStmt[T]{Stmt: s}
	return SelectContext(ctx, wrapper, dest, "", args...)
}
// GetContext executes the prepared statement and scans the single resulting
// row into dest.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func (s *GenericStmt[T]) GetContext(ctx context.Context, dest interface{}, args ...interface{}) error {
	wrapper := &qStmt[T]{Stmt: s}
	return GetContext(ctx, wrapper, dest, "", args...)
}
// MustExecContext executes the prepared statement and panics on failure. The
// query text in any panic is blank because Stmt does not expose its query.
// Any placeholder parameters are replaced with supplied args.
func (s *GenericStmt[T]) MustExecContext(ctx context.Context, args ...interface{}) sql.Result {
	wrapper := &qStmt[T]{Stmt: s}
	return MustExecContext(ctx, wrapper, "", args...)
}
// QueryRowxContext executes the prepared statement and returns an *sqlx.Row.
// Any placeholder parameters are replaced with supplied args.
func (s *GenericStmt[T]) QueryRowxContext(ctx context.Context, args ...interface{}) *Row {
	return (&qStmt[T]{Stmt: s}).QueryRowxContext(ctx, "", args...)
}
// QueryxContext executes the prepared statement and returns *sqlx.Rows.
// Any placeholder parameters are replaced with supplied args.
func (s *GenericStmt[T]) QueryxContext(ctx context.Context, args ...interface{}) (*Rows, error) {
	return (&qStmt[T]{Stmt: s}).QueryxContext(ctx, "", args...)
}
// QueryContext runs the wrapped statement with args under ctx; query is ignored.
func (q *qStmt[T]) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	stmt := q.Stmt
	return stmt.QueryContext(ctx, args...)
}
// QueryxContext runs the wrapped prepared statement with args and wraps the
// result in *Rows using the statement's options and Mapper. The query
// argument is ignored.
func (q *qStmt[T]) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
	stmt := q.Stmt
	sqlRows, err := stmt.QueryContext(ctx, args...)
	if err != nil {
		return nil, err
	}
	return &Rows{Rows: sqlRows, options: stmt.options, Mapper: stmt.Mapper}, nil
}
// QueryRowxContext runs the wrapped prepared statement with args and returns
// a *Row holding both the rows and the query error, so the error surfaces
// when the caller scans. The query argument is ignored.
func (q *qStmt[T]) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row {
	stmt := q.Stmt
	row := &Row{options: stmt.options, Mapper: stmt.Mapper}
	row.rows, row.err = stmt.QueryContext(ctx, args...)
	return row
}
// ExecContext executes the wrapped prepared statement with args. The query
// argument is ignored.
func (q *qStmt[T]) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	result, err := q.Stmt.ExecContext(ctx, args...)
	return result, err
}
package sqlx
import (
	"context"
	"errors"
)
// depthKey is the context-key type under which the transaction nesting depth
// (a uint8) is stored.
type depthKey string

// transactionPointer is the context-key type under which the active *Tx is
// stored.
type transactionPointer string

// Context keys used by the transaction helpers. Giving them unexported types
// keeps them from colliding with keys set by other packages.
const (
	dbKeyDepth depthKey           = "tradeDBTXDepth"
	txKey      transactionPointer = "transactionPointer"
)
// getTxDepth reports the transaction nesting depth stored in ctx, or 0 when
// ctx carries none. It panics if the stored value is not a uint8.
func getTxDepth(ctx context.Context) uint8 {
	switch v := ctx.Value(dbKeyDepth).(type) {
	case nil:
		return 0
	default:
		return v.(uint8)
	}
}
// getTransactionDB returns the *Tx stored in ctx by an enclosing transaction
// helper, or nil when there is none. It panics if the stored value is not a
// *Tx.
func getTransactionDB(ctx context.Context) *Tx {
	stored := ctx.Value(txKey)
	if stored != nil {
		return stored.(*Tx)
	}
	return nil
}
// TransactContext runs txFunc inside a database transaction. It commits when
// txFunc returns nil and rolls back otherwise.
//
// If ctx already carries a transaction started by an enclosing
// TransactContext call, txFunc joins it: txFunc receives that *Tx, and only
// the outermost call commits or rolls back. Otherwise db must be a *DB, from
// which a new transaction is begun with Beginx. Beginx does not take ctx, so
// ctx does not bound the transaction itself; ctx.Err() is checked after
// txFunc returns.
//
// For an outermost call, the returned error is:
//   - the error from beginning the transaction, if that fails;
//   - txFunc's error, after rolling back (a rollback failure is not reported);
//   - the rollback error, or else ctx.Err(), if ctx was cancelled by the time
//     txFunc returned nil (the transaction is rolled back, not committed);
//   - otherwise the result of Commit.
//
// If txFunc panics, the transaction is rolled back and the panic re-raised.
func TransactContext(ctx context.Context, db Queryable, txFunc func(context.Context, Queryable) error) (err error) {
	depth := getTxDepth(ctx)

	if outer := getTransactionDB(ctx); outer != nil {
		// Nested call: run inside the enclosing transaction. The outermost
		// TransactContext owns commit/rollback.
		return txFunc(context.WithValue(ctx, dbKeyDepth, depth+1), outer)
	}

	base, ok := db.(*DB)
	if !ok {
		// Without a *DB there is nothing to begin a transaction on; handing
		// txFunc a nil *Tx would only fail later with a nil dereference.
		return errors.New("sqlx: TransactContext requires a *DB when ctx carries no transaction")
	}

	tx, err := base.Beginx()
	if err != nil {
		return err
	}
	ctx = context.WithValue(ctx, txKey, tx)
	ctx = context.WithValue(ctx, dbKeyDepth, depth+1)

	// err is a named result so the commit/rollback outcome decided here is
	// what the caller actually receives.
	defer func() {
		if p := recover(); p != nil {
			// Best effort: the re-raised panic is what the caller sees.
			_ = tx.Rollback()
			panic(p)
		}
		if err != nil {
			// Keep txFunc's error; a rollback failure would be secondary.
			_ = tx.Rollback()
			return
		}
		if ctxErr := ctx.Err(); ctxErr != nil {
			// The context was cancelled: do not commit, and do not report success.
			if rbErr := tx.Rollback(); rbErr != nil {
				err = rbErr
			} else {
				err = ctxErr
			}
			return
		}
		err = tx.Commit()
	}()

	return txFunc(ctx, tx)
}
// Transact is TransactContext with a background context: txFunc runs in a
// transaction begun on db (or in the context's existing one), which is
// committed on success and rolled back on error.
func Transact(db Queryable, txFunc func(context.Context, Queryable) error) error {
	ctx := context.Background()
	return TransactContext(ctx, db, txFunc)
}
package types
import (
"bytes"
"compress/gzip"
"database/sql/driver"
"encoding/json"
"errors"
"io"
)
// GzippedText is a []byte which transparently gzips data being submitted to
// a database and ungzips data being Scanned from a database.
type GzippedText []byte

// Value implements the driver.Valuer interface, returning the gzip-compressed
// form of g as a []byte. Errors from the gzip writer are returned rather than
// discarded, since a failed Write or Close would otherwise silently produce a
// truncated stream.
func (g GzippedText) Value() (driver.Value, error) {
	buf := bytes.NewBuffer(make([]byte, 0, len(g)))
	w := gzip.NewWriter(buf)
	if _, err := w.Write(g); err != nil {
		// Release the writer; the Write error is the one worth reporting.
		_ = w.Close()
		return nil, err
	}
	// Close flushes pending compressed data and writes the gzip footer.
	if err := w.Close(); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
// Scan implements the sql.Scanner interface. src must be a string or []byte
// holding gzip-compressed data; it is decompressed and the plain bytes are
// stored in *g. Any other source type, or invalid gzip data, yields an error.
func (g *GzippedText) Scan(src interface{}) error {
	var compressed []byte
	if s, isString := src.(string); isString {
		compressed = []byte(s)
	} else if b, isBytes := src.([]byte); isBytes {
		compressed = b
	} else {
		//lint:ignore ST1005 changing this could break consumers of this package
		return errors.New("Incompatible type for GzippedText")
	}

	zr, err := gzip.NewReader(bytes.NewReader(compressed))
	if err != nil {
		return err
	}
	defer zr.Close()

	plain, err := io.ReadAll(zr)
	if err != nil {
		return err
	}
	*g = GzippedText(plain)
	return nil
}
// JSONText is a json.RawMessage, which is a []byte underneath.
// Value() validates the json format in the source, and returns an error if
// the json is not valid. Scan does no validation. JSONText additionally
// implements `Unmarshal`, which unmarshals the json within to an interface{}.
type JSONText json.RawMessage

// emptyJSON is the canonical value of an empty JSONText ("{}"). It is shared
// package state: copy it rather than handing it out directly, because code
// that appends into a JSONText's backing array would otherwise overwrite it.
var emptyJSON = JSONText("{}")

// MarshalJSON implements json.Marshaler. It returns j itself, or "{}" when j
// is empty. The empty case returns a fresh copy so that callers writing to the
// result cannot modify the shared emptyJSON value.
func (j JSONText) MarshalJSON() ([]byte, error) {
	if len(j) == 0 {
		return append([]byte(nil), emptyJSON...), nil
	}
	return j, nil
}
// UnmarshalJSON implements json.Unmarshaler by replacing the contents of *j
// with a copy of data. *j's existing storage is reused when large enough. A
// nil receiver yields an error.
func (j *JSONText) UnmarshalJSON(data []byte) error {
	if j != nil {
		*j = append((*j)[:0], data...)
		return nil
	}
	return errors.New("JSONText: UnmarshalJSON on nil pointer")
}
// Value implements driver.Valuer. It first checks that j holds valid JSON by
// unmarshalling it into a throwaway json.RawMessage. On failure it returns an
// empty []byte plus the decode error; otherwise it returns j's bytes.
func (j JSONText) Value() (driver.Value, error) {
	var probe json.RawMessage
	if err := j.Unmarshal(&probe); err != nil {
		return []byte{}, err
	}
	return []byte(j), nil
}
// Scan implements sql.Scanner, storing a copy of src in *j. No JSON
// validation is done.
//
// A string or non-empty []byte is stored as-is. An empty []byte or a SQL NULL
// (nil) is stored as "{}". Any other type is an error. The bytes are copied
// into *j's existing storage (reused when large enough), so *j never aliases
// src.
func (j *JSONText) Scan(src interface{}) error {
	var source []byte
	switch t := src.(type) {
	case string:
		source = []byte(t)
	case []byte:
		if len(t) == 0 {
			source = emptyJSON
		} else {
			source = t
		}
	case nil:
		// Copy "{}" exactly like the empty []byte case. Assigning emptyJSON
		// to *j here would let the append below leave *j zero-length while
		// still pointing into the shared emptyJSON array.
		source = emptyJSON
	default:
		//lint:ignore ST1005 changing this could break consumers of this package
		return errors.New("Incompatible type for JSONText")
	}
	*j = append((*j)[0:0], source...)
	return nil
}
// Unmarshal unmarshals the JSON in j into v, as json.Unmarshal does.
//
// An empty j is first set to "{}", so v receives an empty object. *j gets its
// own copy of "{}": sharing emptyJSON's backing array would let a later
// append into *j (e.g. from Scan or UnmarshalJSON) overwrite the package-wide
// value.
func (j *JSONText) Unmarshal(v interface{}) error {
	if len(*j) == 0 {
		*j = append(JSONText(nil), emptyJSON...)
	}
	return json.Unmarshal([]byte(*j), v)
}
// String returns the raw JSON text held in j, so fmt prints JSONText values
// as text instead of as a byte slice.
func (j JSONText) String() string {
	raw := []byte(j)
	return string(raw)
}
// NullJSONText represents a JSONText that may be SQL NULL.
// Its Scan method lets it be used as a scan destination, similar to
// sql.NullString, and its Value method produces NULL when Valid is false.
//
// NOTE(review): JSONText is embedded, so JSONText's MarshalJSON and
// UnmarshalJSON are promoted to NullJSONText. Marshalling a NULL
// (Valid == false) NullJSONText therefore yields the JSONText contents
// ("{}" after a NULL scan) rather than JSON null, and unmarshalling does not
// update Valid. Confirm this is intended.
type NullJSONText struct {
	JSONText
	Valid bool // Valid is true if JSONText is not NULL
}
// Scan implements the sql.Scanner interface.
//
// A nil value (SQL NULL) sets Valid to false and JSONText to "{}". Any other
// value sets Valid to true and is delegated to JSONText.Scan. Valid is set
// before delegating, so it stays true even if that scan fails.
func (n *NullJSONText) Scan(value interface{}) error {
	if value == nil {
		// Use a private copy of "{}". JSONText.Scan appends into
		// n.JSONText's backing array, which must not be the shared emptyJSON.
		n.JSONText, n.Valid = append(JSONText(nil), emptyJSON...), false
		return nil
	}
	n.Valid = true
	return n.JSONText.Scan(value)
}
// Value implements driver.Valuer. A valid NullJSONText defers to
// JSONText.Value; an invalid one is written as SQL NULL (a nil value).
func (n NullJSONText) Value() (driver.Value, error) {
	if n.Valid {
		return n.JSONText.Value()
	}
	return nil, nil
}
// BitBool is a bool stored in a MySQL BIT(1) column, which avoids spending a
// whole byte per value as MySQL's TINYINT-based boolean does.
type BitBool bool

// Value implements driver.Valuer, encoding b as a one-byte bitfield for
// BIT(1) storage: 0x01 for true, 0x00 for false.
func (b BitBool) Value() (driver.Value, error) {
	bit := byte(0)
	if b {
		bit = 1
	}
	return []byte{bit}, nil
}
// Scan implements the sql.Scanner interface, decoding a MySQL BIT(1)
// bitfield into b: a first byte equal to 1 means true, any other value false.
// It returns an error if src is not a []byte (including SQL NULL) or is an
// empty []byte.
func (b *BitBool) Scan(src interface{}) error {
	v, ok := src.([]byte)
	if !ok {
		return errors.New("bad []byte type assertion")
	}
	if len(v) == 0 {
		// Indexing v[0] below would otherwise panic.
		return errors.New("empty []byte for BitBool")
	}
	*b = v[0] == 1
	return nil
}