package sqlx import ( "bytes" "database/sql/driver" "errors" "reflect" "strconv" "strings" "sync" "github.com/muir/sqltoken" "github.com/vinovest/sqlx/reflectx" ) // Bindvar types supported by Rebind, BindMap and BindStruct. const ( UNKNOWN = iota QUESTION DOLLAR NAMED AT ) var defaultBinds = map[int][]string{ DOLLAR: {"postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "ql", "nrpostgres", "cockroach"}, QUESTION: {"mysql", "sqlite3", "nrmysql", "nrsqlite3"}, NAMED: {"oci8", "ora", "goracle", "godror"}, AT: {"sqlserver", "azuresql"}, } var binds sync.Map var rebindConfigs = func() []sqltoken.Config { configs := make([]sqltoken.Config, AT+1) pg := sqltoken.PostgreSQLConfig() pg.NoticeQuestionMark = true pg.NoticeDollarNumber = false pg.SeparatePunctuation = true configs[DOLLAR] = pg ora := sqltoken.OracleConfig() ora.NoticeColonWord = false ora.NoticeQuestionMark = true ora.SeparatePunctuation = true configs[NAMED] = ora ssvr := sqltoken.SQLServerConfig() ssvr.NoticeAtWord = false ssvr.NoticeQuestionMark = true ssvr.SeparatePunctuation = true configs[AT] = ssvr return configs }() func init() { for bind, drivers := range defaultBinds { for _, driver := range drivers { BindDriver(driver, bind) } } } // BindType returns the bindtype for a given database given a drivername. func BindType(driverName string) int { itype, ok := binds.Load(driverName) if !ok { return UNKNOWN } return itype.(int) } // BindDriver sets the BindType for driverName to bindType. func BindDriver(driverName string, bindType int) { binds.Store(driverName, bindType) } // Rebind a query from the default bindtype (QUESTION) to the target bindtype. func Rebind(bindType int, query string) string { switch bindType { case QUESTION, UNKNOWN: return query } config := rebindConfigs[bindType] tokens := sqltoken.Tokenize(query, config) rqb := make([]byte, 0, len(query)+10) var j int for _, token := range tokens { if token.Type != sqltoken.QuestionMark { rqb = append(rqb, ([]byte)(token.Text)...) 
continue } switch bindType { case DOLLAR: rqb = append(rqb, '$') case NAMED: rqb = append(rqb, ':', 'a', 'r', 'g') case AT: rqb = append(rqb, '@', 'p') } j++ rqb = strconv.AppendInt(rqb, int64(j), 10) } return string(rqb) } // Experimental implementation of Rebind which uses a bytes.Buffer. The code is // much simpler and should be more resistant to odd unicode, but it is twice as // slow. Kept here for benchmarking purposes and to possibly replace Rebind if // problems arise with its somewhat naive handling of unicode. func rebindBuff(bindType int, query string) string { if bindType != DOLLAR { return query } b := make([]byte, 0, len(query)) rqb := bytes.NewBuffer(b) j := 1 for _, r := range query { if r == '?' { rqb.WriteRune('$') rqb.WriteString(strconv.Itoa(j)) j++ } else { rqb.WriteRune(r) } } return rqb.String() } func asSliceForIn(i interface{}) (v reflect.Value, ok bool) { if i == nil { return reflect.Value{}, false } v = reflect.ValueOf(i) t := reflectx.Deref(v.Type()) // Only expand slices if t.Kind() != reflect.Slice { return reflect.Value{}, false } // []byte is a driver.Value type so it should not be expanded if t == reflect.TypeOf([]byte{}) { return reflect.Value{}, false } return v, true } // In expands slice values in args, returning the modified query string // and a new arg list that can be executed by a database. The `query` should // use the `?` bindVar. The return value uses the `?` bindVar. 
func In(query string, args ...interface{}) (string, []interface{}, error) { // argMeta stores reflect.Value and length for slices and // the value itself for non-slice arguments type argMeta struct { v reflect.Value i interface{} length int } var flatArgsCount int var anySlices bool var stackMeta [32]argMeta var meta []argMeta if len(args) <= len(stackMeta) { meta = stackMeta[:len(args)] } else { meta = make([]argMeta, len(args)) } for i, arg := range args { if a, ok := arg.(driver.Valuer); ok { var err error arg, err = callValuerValue(a) if err != nil { return "", nil, err } } if v, ok := asSliceForIn(arg); ok { meta[i].length = v.Len() meta[i].v = v anySlices = true flatArgsCount += meta[i].length if meta[i].length == 0 { return "", nil, errors.New("empty slice passed to 'in' query") } } else { meta[i].i = arg flatArgsCount++ } } // don't do any parsing if there aren't any slices; note that this means // some errors that we might have caught below will not be returned. if !anySlices { return query, args, nil } newArgs := make([]interface{}, 0, flatArgsCount) var buf strings.Builder buf.Grow(len(query) + len(", ?")*flatArgsCount) var arg, offset int for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') { if arg >= len(meta) { // if an argument wasn't passed, lets return an error; this is // not actually how database/sql Exec/Query works, but since we are // creating an argument list programmatically, we want to be able // to catch these programmer errors earlier. return "", nil, errors.New("number of bindVars exceeds arguments") } argMeta := meta[arg] arg++ // not a slice, continue. // our questionmark will either be written before the next expansion // of a slice or after the loop when writing the rest of the query if argMeta.length == 0 { offset = offset + i + 1 newArgs = append(newArgs, argMeta.i) continue } // write everything up to and including our ? 
character buf.WriteString(query[:offset+i+1]) for si := 1; si < argMeta.length; si++ { buf.WriteString(", ?") } newArgs = appendReflectSlice(newArgs, argMeta.v, argMeta.length) // slice the query and reset the offset. this avoids some bookkeeping for // the write after the loop query = query[offset+i+1:] offset = 0 } buf.WriteString(query) if arg < len(meta) { return "", nil, errors.New("number of bindVars less than number arguments") } return buf.String(), newArgs, nil } func appendReflectSlice(args []interface{}, v reflect.Value, vlen int) []interface{} { switch val := v.Interface().(type) { case []interface{}: args = append(args, val...) case []int: for i := range val { args = append(args, val[i]) } case []string: for i := range val { args = append(args, val[i]) } default: for si := 0; si < vlen; si++ { args = append(args, v.Index(si).Interface()) } } return args } // callValuerValue returns vr.Value(), with one exception: // If vr.Value is an auto-generated method on a pointer type and the // pointer is nil, it would panic at runtime in the panicwrap // method. Treat it like nil instead. // Issue 8415. // // This is so people can implement driver.Value on value types and // still use nil pointers to those types to mean nil/NULL, just like // string/*string. // // This function is copied from database/sql/driver package // and mirrored in the database/sql package. func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Pointer && rv.IsNil() && rv.Type().Elem().Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) { return nil, nil } return vr.Value() }
package sqlx // Named Query Support // // * BindMap - bind query bindvars to map/struct args // * NamedExec, NamedQuery - named query w/ struct or map // * NamedStmt - a pre-compiled named query which is a prepared statement // // Internal Interfaces: // // * compileNamedQuery - rebind a named query, returning a query and list of names // * bindArgs, bindMapArgs, bindAnyArgs - given a list of names, return an arglist // import ( "bytes" "database/sql" "fmt" "reflect" "strconv" "strings" "github.com/muir/sqltoken" "github.com/vinovest/sqlx/reflectx" ) // NamedStmt is a prepared statement that executes named queries. Prepare it // how you would execute a NamedQuery, but pass in a struct or map when executing. type NamedStmt struct { Params []string QueryString string Stmt *Stmt } // Close closes the named statement. func (n *NamedStmt) Close() error { return n.Stmt.Close() } // Exec executes a named statement using the struct passed. // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { return *new(sql.Result), err } return n.Stmt.Exec(args...) } // Query executes a named statement using the struct argument, returning rows. // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { return nil, err } return n.Stmt.Query(args...) } // QueryRow executes a named statement against the database. Because sqlx cannot // create a *sql.Row with an error condition pre-set for binding errors, sqlx // returns a *sqlx.Row instead. // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) QueryRow(arg interface{}) *Row { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { return &Row{err: err} } return n.Stmt.QueryRowx(args...) 
} // MustExec execs a NamedStmt, panicing on error // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) MustExec(arg interface{}) sql.Result { res, err := n.Exec(arg) if err != nil { panic(err) } return res } // Queryx using this NamedStmt // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) { r, err := n.Query(arg) if err != nil { return nil, err } return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err } // QueryRowx this NamedStmt. Because of limitations with QueryRow, this is // an alias for QueryRow. // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) QueryRowx(arg interface{}) *Row { return n.QueryRow(arg) } // Select using this NamedStmt // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Select(dest interface{}, arg interface{}) error { rows, err := n.Queryx(arg) if err != nil { return err } // if something happens here, we want to make sure the rows are Closed defer rows.Close() return scanAll(rows, dest, false) } // Get using this NamedStmt // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Get(dest interface{}, arg interface{}) error { r := n.QueryRowx(arg) return r.scanAny(dest, false) } // Unsafe creates an unsafe version of the NamedStmt func (n *NamedStmt) Unsafe() *NamedStmt { r := &NamedStmt{Params: n.Params, Stmt: n.Stmt, QueryString: n.QueryString} r.Stmt.unsafe = true return r } // A union interface of preparer and binder, required to be able to prepare // named statements (as the bindtype must be determined). 
type namedPreparer interface { Preparer binder } func prepareNamed(p namedPreparer, query string) (*NamedStmt, error) { bindType := BindType(p.DriverName()) compiled, err := compileNamedQuery([]byte(query), bindType) if err != nil { return nil, err } stmt, err := Preparex(p, compiled.query) if err != nil { return nil, err } return &NamedStmt{ QueryString: compiled.query, Params: compiled.names, Stmt: stmt, }, nil } // convertMapStringInterface attempts to convert v to map[string]interface{}. // Unlike v.(map[string]interface{}), this function works on named types that // are convertible to map[string]interface{} as well. func convertMapStringInterface(v interface{}) (map[string]interface{}, bool) { var m map[string]interface{} mtype := reflect.TypeOf(m) t := reflect.TypeOf(v) if !t.ConvertibleTo(mtype) { return nil, false } return reflect.ValueOf(v).Convert(mtype).Interface().(map[string]interface{}), true } func bindAnyArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) { if maparg, ok := convertMapStringInterface(arg); ok { return bindMapArgs(names, maparg) } return bindArgs(names, arg, m) } // private interface to generate a list of interfaces from a given struct // type, given a list of names to pull out of the struct. Used by public // BindStruct interface. func bindArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) { arglist := make([]interface{}, 0, len(names)) // grab the indirected value of arg var v reflect.Value for v = reflect.ValueOf(arg); v.Kind() == reflect.Ptr; { v = v.Elem() } err := m.TraversalsByNameFunc(v.Type(), names, func(i int, t []int) error { if len(t) == 0 { return fmt.Errorf("could not find name %s in %#v", names[i], arg) } val := reflectx.FieldByIndexesReadOnly(v, t) arglist = append(arglist, val.Interface()) return nil }) return arglist, err } // like bindArgs, but for maps. 
func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, error) { arglist := make([]interface{}, 0, len(names)) for _, name := range names { val, ok := arg[name] if !ok { return arglist, fmt.Errorf("could not find name %s in %#v", name, arg) } arglist = append(arglist, val) } return arglist, nil } // bindStruct binds a named parameter query with fields from a struct argument. // The rules for binding field names to parameter names follow the same // conventions as for StructScan, including obeying the `db` struct tags. func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { compiled, err := compileNamedQuery([]byte(query), bindType) if err != nil { return "", []interface{}{}, err } arglist, err := bindAnyArgs(compiled.names, arg, m) if err != nil { return "", []interface{}{}, err } return compiled.query, arglist, nil } func fixBound(cq *compiledQueryResult, loop int) { if cq.valuesStart == nil || cq.valuesEnd == nil { return } buffer := bytes.NewBuffer(make([]byte, 0, (int(*cq.valuesEnd-1)-int(*cq.valuesStart))*loop+ // bytes for commas, too (loop-1)+ // plus the query (len(cq.query)-int(*cq.valuesEnd)+int(*cq.valuesStart)))) buffer.WriteString(cq.query[0:*cq.valuesEnd]) for i := 0; i < loop-1; i++ { buffer.WriteString(",") buffer.WriteString(cq.query[*cq.valuesStart:*cq.valuesEnd]) } buffer.WriteString(cq.query[*cq.valuesEnd:]) cq.query = buffer.String() } // bindArray binds a named parameter query with fields from an array or slice of // structs argument. func bindArray(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { // do the initial binding with QUESTION; if bindType is not question, // we can rebind it at the end. 
compiled, err := compileNamedQuery([]byte(query), QUESTION) if err != nil { return "", []interface{}{}, err } arrayValue := reflect.ValueOf(arg) arrayLen := arrayValue.Len() if arrayLen == 0 { return "", []interface{}{}, fmt.Errorf("length of array is 0: %#v", arg) } arglist := make([]interface{}, 0, len(compiled.names)*arrayLen) for i := 0; i < arrayLen; i++ { elemArglist, err := bindAnyArgs(compiled.names, arrayValue.Index(i).Interface(), m) if err != nil { return "", []interface{}{}, err } arglist = append(arglist, elemArglist...) } if arrayLen > 1 { fixBound(&compiled, arrayLen) } // adjust binding type if we weren't on question bound := compiled.query if bindType != QUESTION { bound = Rebind(bindType, bound) } return bound, arglist, nil } // bindMap binds a named parameter query with a map of arguments. func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) { compiled, err := compileNamedQuery([]byte(query), bindType) if err != nil { return "", []interface{}{}, err } arglist, err := bindMapArgs(compiled.names, args) return compiled.query, arglist, err } var namedParseConfigs = func() []sqltoken.Config { configs := make([]sqltoken.Config, AT+1) pg := sqltoken.PostgreSQLConfig() pg.NoticeColonWord = true pg.ColonWordIncludesUnicode = true pg.NoticeDollarNumber = false pg.NoticeQuestionMark = true pg.SeparatePunctuation = true configs[DOLLAR] = pg ora := sqltoken.OracleConfig() ora.ColonWordIncludesUnicode = true ora.NoticeQuestionMark = true ora.SeparatePunctuation = true configs[NAMED] = ora ssvr := sqltoken.SQLServerConfig() ssvr.NoticeColonWord = true ssvr.ColonWordIncludesUnicode = true ssvr.NoticeAtWord = false ssvr.SeparatePunctuation = true configs[AT] = ssvr mysql := sqltoken.MySQLConfig() mysql.NoticeColonWord = true mysql.ColonWordIncludesUnicode = true mysql.NoticeQuestionMark = true mysql.SeparatePunctuation = true configs[QUESTION] = mysql configs[UNKNOWN] = mysql return configs }() // -- Compilation of 
Named Queries // FIXME: this function isn't safe for unicode named params, as a failing test // can testify. This is not a regression but a failure of the original code // as well. It should be modified to range over runes in a string rather than // bytes, even though this is less convenient and slower. Hopefully the // addition of the prepared NamedStmt (which will only do this once) will make // up for the slightly slower ad-hoc NamedExec/NamedQuery. type compiledQueryResult struct { // the query string with the named parameters swapped out with bindvars query string // the name of the parameter names []string // if set, the start position of the VALUES argument list not including () (end inclusive) valuesStart *uint32 valuesEnd *uint32 } // compile a NamedQuery into an unbound query (using the '?' bindvar) and // a list of names. func compileNamedQuery(qs []byte, bindType int) (compiledQueryResult, error) { r := compiledQueryResult{ names: make([]string, 0, 10), } curpos := uint32(0) rebound := make([]byte, 0, len(qs)) inValues := false inValuesOpenCount := 0 currentVar := 1 tokens := sqltoken.Tokenize(string(qs), namedParseConfigs[bindType]) for _, token := range tokens { if token.Type == sqltoken.Word && strings.EqualFold("values", token.Text) && !inValues { // current behavior: expand the first values and ignore the rest if r.valuesStart == nil { // did we already parse a values statement? inValues = true } } if inValues && token.Type == sqltoken.Punctuation { if token.Text == "(" { start := curpos inValuesOpenCount += 1 if r.valuesStart == nil { r.valuesStart = &start } } if token.Text == ")" { inValuesOpenCount -= 1 if inValuesOpenCount == 0 { end := curpos + 1 r.valuesEnd = &end inValues = false } } } if token.Type != sqltoken.ColonWord { rebound = append(rebound, ([]byte)(token.Text)...) 
curpos += uint32(len(token.Text)) continue } r.names = append(r.names, token.Text[1:]) newBound := "" switch bindType { // oracle only supports named type bind vars even for positional case NAMED: newBound = token.Text case QUESTION, UNKNOWN: newBound = "?" case DOLLAR: newBound = "$" + strconv.Itoa(currentVar) currentVar++ case AT: newBound = "@p" + strconv.Itoa(currentVar) currentVar++ } rebound = append(rebound, []byte(newBound)...) curpos += uint32(len(newBound)) } if inValues { return r, fmt.Errorf("missing closing bracket in VALUES") } r.query = string(rebound) return r, nil } // BindNamed binds a struct or a map to a query with named parameters. // DEPRECATED: use sqlx.Named` instead of this, it may be removed in future. func BindNamed(bindType int, query string, arg interface{}) (string, []interface{}, error) { return bindNamedMapper(bindType, query, arg, mapper()) } // Named takes a query using named parameters and an argument and // returns a new query with a list of args that can be executed by // a database. The return value uses the `?` bindvar. func Named(query string, arg interface{}) (string, []interface{}, error) { return bindNamedMapper(QUESTION, query, arg, mapper()) } func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { t := reflect.TypeOf(arg) k := t.Kind() switch { case k == reflect.Map && t.Key().Kind() == reflect.String: m, ok := convertMapStringInterface(arg) if !ok { return "", nil, fmt.Errorf("sqlx.bindNamedMapper: unsupported map type: %T", arg) } return bindMap(bindType, query, m) case k == reflect.Array || k == reflect.Slice: return bindArray(bindType, query, arg, m) default: return bindStruct(bindType, query, arg, m) } } // NamedQuery binds a named query and then runs Query on the result using the // provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with // map[string]interface{} types. 
func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) { q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) if err != nil { return nil, err } return e.Queryx(q, args...) } // NamedExec uses BindStruct to get a query executable by the driver and // then runs Exec on the result. Returns an error from the binding // or the query execution itself. func NamedExec(e Ext, query string, arg interface{}) (sql.Result, error) { q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) if err != nil { return nil, err } return e.Exec(q, args...) }
package sqlx import ( "context" "database/sql" ) // A union interface of contextPreparer and binder, required to be able to // prepare named statements with context (as the bindtype must be determined). type namedPreparerContext interface { PreparerContext binder } func prepareNamedContext(ctx context.Context, p namedPreparerContext, query string) (*NamedStmt, error) { bindType := BindType(p.DriverName()) compiled, err := compileNamedQuery([]byte(query), bindType) if err != nil { return nil, err } stmt, err := PreparexContext(ctx, p, compiled.query) if err != nil { return nil, err } return &NamedStmt{ QueryString: compiled.query, Params: compiled.names, Stmt: stmt, }, nil } // ExecContext executes a named statement using the struct passed. // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) ExecContext(ctx context.Context, arg interface{}) (sql.Result, error) { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { return *new(sql.Result), err } return n.Stmt.ExecContext(ctx, args...) } // QueryContext executes a named statement using the struct argument, returning rows. // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) QueryContext(ctx context.Context, arg interface{}) (*sql.Rows, error) { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { return nil, err } return n.Stmt.QueryContext(ctx, args...) } // QueryRowContext executes a named statement against the database. Because sqlx cannot // create a *sql.Row with an error condition pre-set for binding errors, sqlx // returns a *sqlx.Row instead. // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) QueryRowContext(ctx context.Context, arg interface{}) *Row { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { return &Row{err: err} } return n.Stmt.QueryRowxContext(ctx, args...) 
} // MustExecContext execs a NamedStmt, panicing on error // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) MustExecContext(ctx context.Context, arg interface{}) sql.Result { res, err := n.ExecContext(ctx, arg) if err != nil { panic(err) } return res } // QueryxContext using this NamedStmt // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) QueryxContext(ctx context.Context, arg interface{}) (*Rows, error) { r, err := n.QueryContext(ctx, arg) if err != nil { return nil, err } return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err } // QueryRowxContext this NamedStmt. Because of limitations with QueryRow, this is // an alias for QueryRow. // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) QueryRowxContext(ctx context.Context, arg interface{}) *Row { return n.QueryRowContext(ctx, arg) } // SelectContext using this NamedStmt // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) SelectContext(ctx context.Context, dest interface{}, arg interface{}) error { rows, err := n.QueryxContext(ctx, arg) if err != nil { return err } // if something happens here, we want to make sure the rows are Closed defer rows.Close() return scanAll(rows, dest, false) } // GetContext using this NamedStmt // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) GetContext(ctx context.Context, dest interface{}, arg interface{}) error { r := n.QueryRowxContext(ctx, arg) return r.scanAny(dest, false) } // NamedQueryContext binds a named query and then runs Query on the result using the // provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with // map[string]interface{} types. 
func NamedQueryContext(ctx context.Context, e ExtContext, query string, arg interface{}) (*Rows, error) { q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) if err != nil { return nil, err } return e.QueryxContext(ctx, q, args...) } // NamedExecContext uses BindStruct to get a query executable by the driver and // then runs Exec on the result. Returns an error from the binding // or the query execution itself. func NamedExecContext(ctx context.Context, e ExtContext, query string, arg interface{}) (sql.Result, error) { q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) if err != nil { return nil, err } return e.ExecContext(ctx, q, args...) }
// Package reflectx implements extensions to the standard reflect lib suitable // for implementing marshalling and unmarshalling packages. The main Mapper type // allows for Go-compatible named attribute access, including accessing embedded // struct attributes and the ability to use functions and struct tags to // customize field names. package reflectx import ( "reflect" "runtime" "strings" "sync" ) // A FieldInfo is metadata for a struct field. type FieldInfo struct { Index []int Path string Field reflect.StructField Zero reflect.Value Name string Options map[string]string Embedded bool Children []*FieldInfo Parent *FieldInfo } // A StructMap is an index of field metadata for a struct. type StructMap struct { Tree *FieldInfo Index []*FieldInfo Paths map[string]*FieldInfo Names map[string]*FieldInfo } // GetByPath returns a *FieldInfo for a given string path. func (f StructMap) GetByPath(path string) *FieldInfo { return f.Paths[path] } // GetByTraversal returns a *FieldInfo for a given integer path. It is // analogous to reflect.FieldByIndex, but using the cached traversal // rather than re-executing the reflect machinery each time. func (f StructMap) GetByTraversal(index []int) *FieldInfo { if len(index) == 0 { return nil } tree := f.Tree for _, i := range index { if i >= len(tree.Children) || tree.Children[i] == nil { return nil } tree = tree.Children[i] } return tree } // Mapper is a general purpose mapper of names to struct fields. A Mapper // behaves like most marshallers in the standard library, obeying a field tag // for name mapping but also providing a basic transform function. type Mapper struct { cache map[reflect.Type]*StructMap tagName string tagMapFunc func(string) string mapFunc func(string) string mutex sync.Mutex } // NewMapper returns a new mapper using the tagName as its struct field tag. // If tagName is the empty string, it is ignored. 
func NewMapper(tagName string) *Mapper {
	return NewMapperTagFunc(tagName, nil, nil)
}

// NewMapperTagFunc returns a new mapper which contains a mapper for field names
// AND a mapper for tag values. This is useful for tags like json which can
// have values like "name,omitempty".
func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper {
	return &Mapper{
		cache:      make(map[reflect.Type]*StructMap),
		tagName:    tagName,
		mapFunc:    mapFunc,
		tagMapFunc: tagMapFunc,
	}
}

// NewMapperFunc returns a new mapper which optionally obeys a field tag and
// a struct field name mapper func given by f. Tags will take precedence, but
// for any other field, the mapped name will be f(field.Name)
func NewMapperFunc(tagName string, f func(string) string) *Mapper {
	return NewMapperTagFunc(tagName, f, nil)
}

// TypeMap returns a mapping of field strings to int slices representing
// the traversal down the struct to reach the field. The mapping for each type
// is computed once and cached; the cache is guarded by m's mutex.
func (m *Mapper) TypeMap(t reflect.Type) *StructMap {
	m.mutex.Lock()
	// defer so the lock is released even if getMapping panics; with a manual
	// Unlock after the call, a panic would leave every later caller blocked.
	defer m.mutex.Unlock()

	if mapping, ok := m.cache[t]; ok {
		return mapping
	}
	// A zero-value Mapper has a nil cache; allocate it on first use instead
	// of panicking on the map write below.
	if m.cache == nil {
		m.cache = make(map[reflect.Type]*StructMap)
	}
	mapping := getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc)
	m.cache[t] = mapping
	return mapping
}

// FieldMap returns the mapper's mapping of field names to reflect values. Panics
// if v's Kind is not Struct, or v is not Indirectable to a struct kind.
// Nil pointer and map fields on the way are allocated (see FieldByIndexes).
func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value {
	v = reflect.Indirect(v)
	mustBe(v, reflect.Struct)

	tm := m.TypeMap(v.Type())
	r := make(map[string]reflect.Value, len(tm.Names))
	for tagName, fi := range tm.Names {
		r[tagName] = FieldByIndexes(v, fi.Index)
	}
	return r
}

// FieldByName returns a field by its mapped name as a reflect.Value.
// Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind.
// Returns zero Value if the name is not found.
func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value { v = reflect.Indirect(v) mustBe(v, reflect.Struct) tm := m.TypeMap(v.Type()) fi, ok := tm.Names[name] if !ok { return v } return FieldByIndexes(v, fi.Index) } // FieldsByName returns a slice of values corresponding to the slice of names // for the value. Panics if v's Kind is not Struct or v is not Indirectable // to a struct Kind. Returns zero Value for each name not found. func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value { v = reflect.Indirect(v) mustBe(v, reflect.Struct) tm := m.TypeMap(v.Type()) vals := make([]reflect.Value, 0, len(names)) for _, name := range names { fi, ok := tm.Names[name] if !ok { vals = append(vals, *new(reflect.Value)) } else { vals = append(vals, FieldByIndexes(v, fi.Index)) } } return vals } // TraversalsByName returns a slice of int slices which represent the struct // traversals for each mapped name. Panics if t is not a struct or Indirectable // to a struct. Returns empty int slice for each name not found. func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int { r := make([][]int, 0, len(names)) m.TraversalsByNameFunc(t, names, func(_ int, i []int) error { if i == nil { r = append(r, []int{}) } else { r = append(r, i) } return nil }) return r } // TraversalsByNameFunc traverses the mapped names and calls fn with the index of // each name and the struct traversal represented by that name. Panics if t is not // a struct or Indirectable to a struct. Returns the first error returned by fn or nil. 
func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(int, []int) error) error { t = Deref(t) mustBe(t, reflect.Struct) tm := m.TypeMap(t) for i, name := range names { fi, ok := tm.Names[name] if !ok { if err := fn(i, nil); err != nil { return err } } else { if err := fn(i, fi.Index); err != nil { return err } } } return nil } // FieldByIndexes returns a value for the field given by the struct traversal // for the given value. func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { for _, i := range indexes { v = reflect.Indirect(v).Field(i) // if this is a pointer and it's nil, allocate a new value and set it if v.Kind() == reflect.Ptr && v.IsNil() { alloc := reflect.New(Deref(v.Type())) v.Set(alloc) } if v.Kind() == reflect.Map && v.IsNil() { v.Set(reflect.MakeMap(v.Type())) } } return v } // FieldByIndexesReadOnly returns a value for a particular struct traversal, // but is not concerned with allocating nil pointers because the value is // going to be used for reading and not setting. func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value { for _, i := range indexes { v = reflect.Indirect(v).Field(i) } return v } // Deref is Indirect for reflect.Types func Deref(t reflect.Type) reflect.Type { if t.Kind() == reflect.Ptr { t = t.Elem() } return t } // -- helpers & utilities -- type kinder interface { Kind() reflect.Kind } // mustBe checks a value against a kind, panicing with a reflect.ValueError // if the kind isn't that which is required. 
func mustBe(v kinder, expected reflect.Kind) { if k := v.Kind(); k != expected { panic(&reflect.ValueError{Method: methodName(), Kind: k}) } } // methodName returns the caller of the function calling methodName func methodName() string { pc, _, _, _ := runtime.Caller(2) f := runtime.FuncForPC(pc) if f == nil { return "unknown method" } return f.Name() } type typeQueue struct { t reflect.Type fi *FieldInfo pp string // Parent path } // A copying append that creates a new slice each time. func apnd(is []int, i int) []int { x := make([]int, len(is)+1) copy(x, is) x[len(x)-1] = i return x } type mapf func(string) string // parseName parses the tag and the target name for the given field using // the tagName (eg 'json' for `json:"foo"` tags), mapFunc for mapping the // field's name to a target name, and tagMapFunc for mapping the tag to // a target name. func parseName(field reflect.StructField, tagName string, mapFunc, tagMapFunc mapf) (tag, fieldName string) { // first, set the fieldName to the field's name fieldName = field.Name // if a mapFunc is set, use that to override the fieldName if mapFunc != nil { fieldName = mapFunc(fieldName) } // if there's no tag to look for, return the field name if tagName == "" { return "", fieldName } // if this tag is not set using the normal convention in the tag, // then return the fieldname.. this check is done because according // to the reflect documentation: // If the tag does not have the conventional format, // the value returned by Get is unspecified. // which doesn't sound great. 
if !strings.Contains(string(field.Tag), tagName+":") { return "", fieldName } // at this point we're fairly sure that we have a tag, so lets pull it out tag = field.Tag.Get(tagName) // if we have a mapper function, call it on the whole tag // XXX: this is a change from the old version, which pulled out the name // before the tagMapFunc could be run, but I think this is the right way if tagMapFunc != nil { tag = tagMapFunc(tag) } // finally, split the options from the name parts := strings.Split(tag, ",") fieldName = parts[0] return tag, fieldName } // parseOptions parses options out of a tag string, skipping the name func parseOptions(tag string) map[string]string { parts := strings.Split(tag, ",") options := make(map[string]string, len(parts)) if len(parts) > 1 { for _, opt := range parts[1:] { // short circuit potentially expensive split op if strings.Contains(opt, "=") { kv := strings.Split(opt, "=") options[kv[0]] = kv[1] continue } options[opt] = "" } } return options } // getMapping returns a mapping for the t type, using the tagName, mapFunc and // tagMapFunc to determine the canonical names of fields. 
func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc mapf) *StructMap {
	// m collects every mapped field in the order it is visited (breadth-first).
	m := []*FieldInfo{}

	// root is a synthetic node standing for t itself; t's fields become its children.
	root := &FieldInfo{}
	queue := []typeQueue{}
	queue = append(queue, typeQueue{Deref(t), root, ""})

QueueLoop:
	for len(queue) != 0 {
		// pop the first item off of the queue
		tq := queue[0]
		queue = queue[1:]

		// ignore recursive field: skip this node when an ancestor has the same
		// field type, so self-referencing types do not expand forever
		for p := tq.fi.Parent; p != nil; p = p.Parent {
			if tq.fi.Field.Type == p.Field.Type {
				continue QueueLoop
			}
		}

		// only struct types have fields to descend into; anything else queued
		// here (e.g. an embedded non-struct type) gets no children
		nChildren := 0
		if tq.t.Kind() == reflect.Struct {
			nChildren = tq.t.NumField()
		}
		// Children is indexed by field position; skipped fields leave a nil slot
		tq.fi.Children = make([]*FieldInfo, nChildren)

		// iterate through all of its fields
		for fieldPos := 0; fieldPos < nChildren; fieldPos++ {

			f := tq.t.Field(fieldPos)

			// parse the tag and the target name using the mapping options for this field
			tag, name := parseName(f, tagName, mapFunc, tagMapFunc)

			// if the name is "-", disabled via a tag, skip it
			if name == "-" {
				continue
			}

			fi := FieldInfo{
				Field:   f,
				Name:    name,
				Zero:    reflect.New(f.Type).Elem(),
				Options: parseOptions(tag),
			}

			// if the path is empty this path is just the name
			if tq.pp == "" {
				fi.Path = fi.Name
			} else {
				fi.Path = tq.pp + "." + fi.Name
			}

			// skip unexported fields; unexported embedded fields are kept so
			// the walk below can still descend into them
			if len(f.PkgPath) != 0 && !f.Anonymous {
				continue
			}

			// bfs search of anonymous embedded structs
			if f.Anonymous {
				// an untagged embedded field contributes its fields at the
				// parent's path level; a tagged one nests them under its own path
				pp := tq.pp
				if tag != "" {
					pp = fi.Path
				}

				fi.Embedded = true
				fi.Index = apnd(tq.fi.Index, fieldPos)
				nChildren := 0
				ft := Deref(f.Type)
				if ft.Kind() == reflect.Struct {
					nChildren = ft.NumField()
				}
				fi.Children = make([]*FieldInfo, nChildren)
				queue = append(queue, typeQueue{Deref(f.Type), &fi, pp})
			} else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) {
				// named struct (or pointer-to-struct) fields are walked too,
				// with their children nested under this field's own path
				fi.Index = apnd(tq.fi.Index, fieldPos)
				fi.Children = make([]*FieldInfo, Deref(f.Type).NumField())
				queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path})
			}

			// the queue holds &fi, so these assignments are also seen by the
			// queued entry; the Index assignments in the branches above compute
			// the same value and are redundant with this one
			fi.Index = apnd(tq.fi.Index, fieldPos)
			fi.Parent = tq.fi
			tq.fi.Children[fieldPos] = &fi
			m = append(m, &fi)
		}
	}

	// build the lookup maps; note that Names is keyed by the dotted Path
	flds := &StructMap{Index: m, Tree: root, Paths: map[string]*FieldInfo{}, Names: map[string]*FieldInfo{}}
	for _, fi := range flds.Index {
		// check if nothing has already been pushed with the same path
		// sometimes you can choose to override a type using embedded struct
		fld, ok := flds.Paths[fi.Path]
		if !ok || fld.Embedded {
			flds.Paths[fi.Path] = fi
			if fi.Name != "" && !fi.Embedded {
				flds.Names[fi.Path] = fi
			}
		}
	}

	return flds
}
package sqlx

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"reflect"
	"strings"
	"sync"

	"github.com/vinovest/sqlx/reflectx"
)

// ErrMultiRows is returned by functions which are expected to work with result sets
// that only contain a single row but multiple rows were returned.
// This typically indicates an issue with the query such as a missing join criteria or
// limit condition or the use of Get(...) when Select(...) was intended.
var ErrMultiRows = errors.New("sql: multiple rows returned")

// Although the NameMapper is convenient, in practice it should not
// be relied on except for application code. If you are writing a library
// that uses sqlx, you should be aware that the name mappings you expect
// can be overridden by your user's application.

// NameMapper is used to map column names to struct field names. By default,
// it uses strings.ToLower to lowercase struct field names. It can be set
// to whatever you want, but it is encouraged to be set before sqlx is used
// as name-to-field mappings are cached after first use on a type.
var NameMapper = strings.ToLower

// origMapper records the NameMapper function value last seen by mapper();
// mapper() compares it with the current NameMapper to detect a reassignment.
var origMapper = reflect.ValueOf(NameMapper)

// Rather than creating on init, this is created when necessary so that
// importers have time to customize the NameMapper.
var mpr *reflectx.Mapper

// mprMu protects mpr (and origMapper, which mapper() updates under the same lock).
var mprMu sync.Mutex

// mapper returns a valid mapper using the configured NameMapper func.
// The package-level mapper is built lazily on first use and rebuilt whenever
// NameMapper no longer matches origMapper.
//
// NOTE(review): NameMapper itself is read here without synchronization, so
// reassigning it concurrently with sqlx use is presumably unsupported.
// NOTE(review): the mpr == nil branch does not refresh origMapper. If
// NameMapper was changed before the first call, the second call rebuilds the
// mapper once more even though nothing changed in between.
func mapper() *reflectx.Mapper {
	mprMu.Lock()
	defer mprMu.Unlock()

	if mpr == nil {
		mpr = reflectx.NewMapperFunc("db", NameMapper)
	} else if origMapper != reflect.ValueOf(NameMapper) {
		// if NameMapper has changed, create a new mapper
		mpr = reflectx.NewMapperFunc("db", NameMapper)
		origMapper = reflect.ValueOf(NameMapper)
	}
	return mpr
}

// isScannable takes a reflect.Type and returns
// whether or not it's Scannable.
Something is scannable if: // - it is not a struct // - it implements sql.Scanner // - it has no exported fields func isScannable(t reflect.Type) bool { if reflect.PointerTo(t).Implements(_scannerInterface) { return true } if t.Kind() != reflect.Struct { return true } // it's not important that we use the right mapper for this particular object, // we're only concerned on how many exported fields this struct has return len(mapper().TypeMap(t).Index) == 0 } // ColScanner is an interface used by MapScan and SliceScan type ColScanner interface { Columns() ([]string, error) Scan(dest ...interface{}) error Err() error } // Queryer is an interface used by Get and Select type Queryer interface { Query(query string, args ...interface{}) (*sql.Rows, error) Queryx(query string, args ...interface{}) (*Rows, error) QueryRowx(query string, args ...interface{}) *Row } // Execer is an interface used by MustExec and LoadFile type Execer interface { Exec(query string, args ...interface{}) (sql.Result, error) } // Binder is an interface for something which can bind queries (Tx, DB) type binder interface { DriverName() string Rebind(string) string BindNamed(string, interface{}) (string, []interface{}, error) } // Ext is a union interface which can bind, query, and exec, used by // NamedQuery and NamedExec. type Ext interface { binder Queryer Execer } // Preparer is an interface used by Preparex. 
type Preparer interface { Prepare(query string) (*sql.Stmt, error) } // determine if any of our extensions are unsafe func isUnsafe(i interface{}) bool { switch v := i.(type) { case Row: return v.unsafe case *Row: return v.unsafe case Rows: return v.unsafe case *Rows: return v.unsafe case NamedStmt: return v.Stmt.unsafe case *NamedStmt: return v.Stmt.unsafe case Stmt: return v.unsafe case *Stmt: return v.unsafe case qStmt: return v.unsafe case *qStmt: return v.unsafe case DB: return v.unsafe case *DB: return v.unsafe case Tx: return v.unsafe case *Tx: return v.unsafe case sql.Rows, *sql.Rows: return false default: return false } } func mapperFor(i interface{}) *reflectx.Mapper { switch i := i.(type) { case DB: return i.Mapper case *DB: return i.Mapper case Tx: return i.Mapper case *Tx: return i.Mapper default: return mapper() } } var _scannerInterface = reflect.TypeOf((*sql.Scanner)(nil)).Elem() //lint:ignore U1000 ignoring this for now var _valuerInterface = reflect.TypeOf((*driver.Valuer)(nil)).Elem() // Row is a reimplementation of sql.Row in order to gain access to the underlying // sql.Rows.Columns() data, necessary for StructScan. type Row struct { err error unsafe bool rows *sql.Rows Mapper *reflectx.Mapper } // Scan is a fixed implementation of sql.Row.Scan, which does not discard the // underlying error from the internal rows object if it exists. // Returns ErrMultiRows if the result set contains more than one row. func (r *Row) Scan(dest ...interface{}) error { if r.err != nil { return r.err } // TODO(bradfitz): for now we need to defensively clone all // []byte that the driver returned (not permitting // *RawBytes in Rows.Scan), since we're about to close // the Rows in our defer, when we return from this function. // the contract with the driver.Next(...) interface is that it // can return slices into read-only temporary memory that's // only valid until the next Scan/Close. But the TODO is that // for a lot of drivers, this copy will be unnecessary. 
We // should provide an optional interface for drivers to // implement to say, "don't worry, the []bytes that I return // from Next will not be modified again." (for instance, if // they were obtained from the network anyway) But for now we // don't care. defer r.rows.Close() for _, dp := range dest { if _, ok := dp.(*sql.RawBytes); ok { return errors.New("sql: RawBytes isn't allowed on Row.Scan") } } if !r.rows.Next() { if err := r.rows.Err(); err != nil { return err } return sql.ErrNoRows } if err := r.rows.Scan(dest...); err != nil { return err } if r.rows.Next() { return ErrMultiRows } else if err := r.rows.Err(); err != nil { return err } // Make sure the query can be processed to completion with no errors. if err := r.rows.Close(); err != nil { return err } return nil } // Columns returns the underlying sql.Rows.Columns(), or the deferred error usually // returned by Row.Scan() func (r *Row) Columns() ([]string, error) { if r.err != nil { return []string{}, r.err } return r.rows.Columns() } // ColumnTypes returns the underlying sql.Rows.ColumnTypes(), or the deferred error func (r *Row) ColumnTypes() ([]*sql.ColumnType, error) { if r.err != nil { return []*sql.ColumnType{}, r.err } return r.rows.ColumnTypes() } // Err returns the error encountered while scanning. func (r *Row) Err() error { return r.err } // Queryable includes all methods shared by sqlx.DB and sqlx.Tx, allowing // either type to be used interchangeably. 
type Queryable interface { Ext ExecerContext PreparerContext QueryerContext Preparer GetContext(context.Context, interface{}, string, ...interface{}) error SelectContext(context.Context, interface{}, string, ...interface{}) error Get(interface{}, string, ...interface{}) error MustExecContext(context.Context, string, ...interface{}) sql.Result PreparexContext(context.Context, string) (*Stmt, error) QueryRowContext(context.Context, string, ...interface{}) *sql.Row Select(interface{}, string, ...interface{}) error QueryRow(string, ...interface{}) *sql.Row PrepareNamedContext(context.Context, string) (*NamedStmt, error) PrepareNamed(string) (*NamedStmt, error) Preparex(string) (*Stmt, error) NamedExec(string, interface{}) (sql.Result, error) NamedExecContext(context.Context, string, interface{}) (sql.Result, error) MustExec(string, ...interface{}) sql.Result NamedQuery(string, interface{}) (*Rows, error) } var _ Queryable = (*DB)(nil) var _ Queryable = (*Tx)(nil) // DB is a wrapper around sql.DB which keeps track of the driverName upon Open, // used mostly to automatically bind named queries using the right bindvars. type DB struct { *sql.DB driverName string unsafe bool Mapper *reflectx.Mapper } // NewDb returns a new sqlx DB wrapper for a pre-existing *sql.DB. The // driverName of the original database is required for named query support. // //lint:ignore ST1003 changing this would break the package interface. func NewDb(db *sql.DB, driverName string) *DB { return &DB{DB: db, driverName: driverName, Mapper: mapper()} } // DriverName returns the driverName passed to the Open function for this DB. func (db *DB) DriverName() string { return db.driverName } // Open is the same as sql.Open, but returns an *sqlx.DB instead. 
func Open(driverName, dataSourceName string) (*DB, error) { db, err := sql.Open(driverName, dataSourceName) if err != nil { return nil, err } return &DB{DB: db, driverName: driverName, Mapper: mapper()}, err } // MustOpen is the same as sql.Open, but returns an *sqlx.DB instead and panics on error. func MustOpen(driverName, dataSourceName string) *DB { db, err := Open(driverName, dataSourceName) if err != nil { panic(err) } return db } // MapperFunc sets a new mapper for this db using the default sqlx struct tag // and the provided mapper function. func (db *DB) MapperFunc(mf func(string) string) { db.Mapper = reflectx.NewMapperFunc("db", mf) } // Rebind transforms a query from QUESTION to the DB driver's bindvar type. func (db *DB) Rebind(query string) string { return Rebind(BindType(db.driverName), query) } // Unsafe returns a version of DB which will silently succeed to scan when // columns in the SQL result have no fields in the destination struct. // sqlx.Stmt and sqlx.Tx which are created from this DB will inherit its // safety behavior. func (db *DB) Unsafe() *DB { return &DB{DB: db.DB, driverName: db.driverName, unsafe: true, Mapper: db.Mapper} } // BindNamed binds a query using the DB driver's bindvar type. func (db *DB) BindNamed(query string, arg interface{}) (string, []interface{}, error) { return bindNamedMapper(BindType(db.driverName), query, arg, db.Mapper) } // NamedQuery using this DB. // Any named placeholder parameters are replaced with fields from arg. func (db *DB) NamedQuery(query string, arg interface{}) (*Rows, error) { return NamedQuery(db, query, arg) } // NamedExec using this DB. // Any named placeholder parameters are replaced with fields from arg. func (db *DB) NamedExec(query string, arg interface{}) (sql.Result, error) { return NamedExec(db, query, arg) } // Select using this DB. // Any placeholder parameters are replaced with supplied args. 
func (db *DB) Select(dest interface{}, query string, args ...interface{}) error { return Select(db, dest, query, args...) } // Get using this DB. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty or contains more than one row. func (db *DB) Get(dest interface{}, query string, args ...interface{}) error { return Get(db, dest, query, args...) } // MustBegin starts a transaction, and panics on error. Returns an *sqlx.Tx instead // of an *sql.Tx. func (db *DB) MustBegin() *Tx { tx, err := db.Beginx() if err != nil { panic(err) } return tx } // Beginx begins a transaction and returns an *sqlx.Tx instead of an *sql.Tx. func (db *DB) Beginx() (*Tx, error) { tx, err := db.DB.Begin() if err != nil { return nil, err } return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err } // Queryx queries the database and returns an *sqlx.Rows. // Any placeholder parameters are replaced with supplied args. func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) { r, err := db.DB.Query(query, args...) if err != nil { return nil, err } return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err } // QueryRowx queries the database and returns an *sqlx.Row. // Any placeholder parameters are replaced with supplied args. func (db *DB) QueryRowx(query string, args ...interface{}) *Row { rows, err := db.DB.Query(query, args...) return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} } // MustExec (panic) runs MustExec using this database. // Any placeholder parameters are replaced with supplied args. func (db *DB) MustExec(query string, args ...interface{}) sql.Result { return MustExec(db, query, args...) 
} // Preparex returns an sqlx.Stmt instead of a sql.Stmt func (db *DB) Preparex(query string) (*Stmt, error) { return Preparex(db, query) } // PrepareNamed returns an sqlx.NamedStmt func (db *DB) PrepareNamed(query string) (*NamedStmt, error) { return prepareNamed(db, query) } // Conn is a wrapper around sql.Conn with extra functionality type Conn struct { *sql.Conn driverName string unsafe bool Mapper *reflectx.Mapper } // Tx is an sqlx wrapper around sql.Tx with extra functionality type Tx struct { *sql.Tx driverName string unsafe bool Mapper *reflectx.Mapper } // DriverName returns the driverName used by the DB which began this transaction. func (tx *Tx) DriverName() string { return tx.driverName } // Rebind a query within a transaction's bindvar type. func (tx *Tx) Rebind(query string) string { return Rebind(BindType(tx.driverName), query) } // Unsafe returns a version of Tx which will silently succeed to scan when // columns in the SQL result have no fields in the destination struct. func (tx *Tx) Unsafe() *Tx { return &Tx{Tx: tx.Tx, driverName: tx.driverName, unsafe: true, Mapper: tx.Mapper} } // BindNamed binds a query within a transaction's bindvar type. func (tx *Tx) BindNamed(query string, arg interface{}) (string, []interface{}, error) { return bindNamedMapper(BindType(tx.driverName), query, arg, tx.Mapper) } // NamedQuery within a transaction. // Any named placeholder parameters are replaced with fields from arg. func (tx *Tx) NamedQuery(query string, arg interface{}) (*Rows, error) { return NamedQuery(tx, query, arg) } // NamedExec a named query within a transaction. // Any named placeholder parameters are replaced with fields from arg. func (tx *Tx) NamedExec(query string, arg interface{}) (sql.Result, error) { return NamedExec(tx, query, arg) } // Select within a transaction. // Any placeholder parameters are replaced with supplied args. 
func (tx *Tx) Select(dest interface{}, query string, args ...interface{}) error {
	return Select(tx, dest, query, args...)
}

// Queryx within a transaction.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) {
	r, err := tx.Tx.Query(query, args...)
	if err != nil {
		return nil, err
	}
	return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, nil
}

// QueryRowx within a transaction.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row {
	rows, err := tx.Tx.Query(query, args...)
	return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper}
}

// Get within a transaction.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty or contains more than one row.
func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error {
	return Get(tx, dest, query, args...)
}

// MustExec runs MustExec within a transaction.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) MustExec(query string, args ...interface{}) sql.Result {
	return MustExec(tx, query, args...)
}

// Preparex a statement within a transaction.
func (tx *Tx) Preparex(query string) (*Stmt, error) {
	return Preparex(tx, query)
}

// Stmtx returns a version of the prepared statement which runs within a transaction. Provided
// stmt can be either *sql.Stmt or *sqlx.Stmt (or an sqlx.Stmt value); anything else panics.
// The returned statement uses the transaction's mapper. It is unsafe if either the
// transaction or the provided sqlx statement is unsafe.
func (tx *Tx) Stmtx(stmt interface{}) *Stmt {
	var s *sql.Stmt
	switch v := stmt.(type) {
	case Stmt:
		s = v.Stmt
	case *Stmt:
		s = v.Stmt
	case *sql.Stmt:
		s = v
	default:
		// %T also formats a nil interface ("<nil>"), whereas
		// reflect.ValueOf(nil).Type() would itself panic inside reflect.
		panic(fmt.Sprintf("non-statement type %T passed to Stmtx", stmt))
	}
	// Keep the unsafe behavior that Preparex(tx, ...) would give (via
	// isUnsafe(tx)), and that of the source statement if it was an unsafe Stmt.
	return &Stmt{Stmt: tx.Stmt(s), unsafe: tx.unsafe || isUnsafe(stmt), Mapper: tx.Mapper}
}

// NamedStmt returns a version of the prepared statement which runs within a transaction.
func (tx *Tx) NamedStmt(stmt *NamedStmt) *NamedStmt { return &NamedStmt{ QueryString: stmt.QueryString, Params: stmt.Params, Stmt: tx.Stmtx(stmt.Stmt), } } // PrepareNamed returns an sqlx.NamedStmt func (tx *Tx) PrepareNamed(query string) (*NamedStmt, error) { return prepareNamed(tx, query) } // Stmt is an sqlx wrapper around sql.Stmt with extra functionality type Stmt struct { *sql.Stmt unsafe bool Mapper *reflectx.Mapper } // Unsafe returns a version of Stmt which will silently succeed to scan when // columns in the SQL result have no fields in the destination struct. func (s *Stmt) Unsafe() *Stmt { return &Stmt{Stmt: s.Stmt, unsafe: true, Mapper: s.Mapper} } // Select using the prepared statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) Select(dest interface{}, args ...interface{}) error { return Select(&qStmt{s}, dest, "", args...) } // Get using the prepared statement. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty or contains more than one row. func (s *Stmt) Get(dest interface{}, args ...interface{}) error { return Get(&qStmt{s}, dest, "", args...) } // MustExec (panic) using this statement. Note that the query portion of the error // output will be blank, as Stmt does not expose its query. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) MustExec(args ...interface{}) sql.Result { return MustExec(&qStmt{s}, "", args...) } // QueryRowx using this statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) QueryRowx(args ...interface{}) *Row { qs := &qStmt{s} return qs.QueryRowx("", args...) } // Queryx using this statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) Queryx(args ...interface{}) (*Rows, error) { qs := &qStmt{s} return qs.Queryx("", args...) 
} // qStmt is an unexposed wrapper which lets you use a Stmt as a Queryer & Execer by // implementing those interfaces and ignoring the `query` argument. type qStmt struct{ *Stmt } func (q *qStmt) Query(query string, args ...interface{}) (*sql.Rows, error) { return q.Stmt.Query(args...) } func (q *qStmt) Queryx(query string, args ...interface{}) (*Rows, error) { r, err := q.Stmt.Query(args...) if err != nil { return nil, err } return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err } func (q *qStmt) QueryRowx(query string, args ...interface{}) *Row { rows, err := q.Stmt.Query(args...) return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper} } func (q *qStmt) Exec(query string, args ...interface{}) (sql.Result, error) { return q.Stmt.Exec(args...) } // Rows is a wrapper around sql.Rows which caches costly reflect operations // during a looped StructScan type Rows struct { *sql.Rows unsafe bool Mapper *reflectx.Mapper // these fields cache memory use for a rows during iteration w/ structScan started bool fields [][]int values []interface{} } // SliceScan using this Rows. func (r *Rows) SliceScan() ([]interface{}, error) { return SliceScan(r) } // MapScan using this Rows. func (r *Rows) MapScan(dest map[string]interface{}) error { return MapScan(r, dest) } // StructScan is like sql.Rows.Scan, but scans a single Row into a single Struct. // Use this and iterate over Rows manually when the memory load of Select() might be // prohibitive. *Rows.StructScan caches the reflect work of matching up column // positions to fields to avoid that overhead per scan, which means it is not safe // to run StructScan on the same Rows instance with different struct types. 
func (r *Rows) StructScan(dest interface{}) error { v := reflect.ValueOf(dest) if v.Kind() != reflect.Ptr { return errors.New("must pass a pointer, not a value, to StructScan destination") } v = v.Elem() if !r.started { columns, err := r.Columns() if err != nil { return err } m := r.Mapper r.fields = m.TraversalsByName(v.Type(), columns) // if we are not unsafe and are missing fields, return an error if f, err := missingFields(r.fields); err != nil && !r.unsafe { return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } r.values = make([]interface{}, len(columns)) r.started = true } err := fieldsByTraversal(v, r.fields, r.values, true) if err != nil { return err } // scan into the struct field pointers and append to our results err = r.Scan(r.values...) if err != nil { return err } return r.Err() } // NextResultSet moves to the next resultset if available and resets the field cache. func (r *Rows) NextResultSet() bool { if !r.Rows.NextResultSet() { return false } // reset fields cache r.started = false return true } // Connect to a database and verify with a ping. func Connect(driverName, dataSourceName string) (*DB, error) { db, err := Open(driverName, dataSourceName) if err != nil { return nil, err } err = db.Ping() if err != nil { db.Close() return nil, err } return db, nil } // MustConnect connects to a database and panics on error. func MustConnect(driverName, dataSourceName string) *DB { db, err := Connect(driverName, dataSourceName) if err != nil { panic(err) } return db } // Preparex prepares a statement. func Preparex(p Preparer, query string) (*Stmt, error) { s, err := p.Prepare(query) if err != nil { return nil, err } return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err } // Select executes a query using the provided Queryer, and StructScans each row // into dest, which must be a slice. If the slice elements are scannable, then // the result set must have only one column. Otherwise, StructScan is used. 
// The *sql.Rows are closed automatically. // Any placeholder parameters are replaced with supplied args. func Select(q Queryer, dest interface{}, query string, args ...interface{}) error { rows, err := q.Queryx(query, args...) if err != nil { return err } // if something happens here, we want to make sure the rows are Closed defer rows.Close() return scanAll(rows, dest, false) } // Get does a QueryRow using the provided Queryer, and scans the resulting row // to dest. If dest is scannable, the result must only have one column. Otherwise, // StructScan is used. Get will return sql.ErrNoRows like row.Scan would. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty or contains more than one row. func Get(q Queryer, dest interface{}, query string, args ...interface{}) error { r := q.QueryRowx(query, args...) return r.scanAny(dest, false) } // LoadFile exec's every statement in a file (as a single call to Exec). // LoadFile may return a nil *sql.Result if errors are encountered locating or // reading the file at path. LoadFile reads the entire file into memory, so it // is not suitable for loading large data dumps, but can be useful for initializing // schemas or loading indexes. // // FIXME: this does not really work with multi-statement files for mattn/go-sqlite3 // or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting // this by requiring something with DriverName() and then attempting to split the // queries will be difficult to get right, and its current driver-specific behavior // is deemed at least not complex in its incorrectness. func LoadFile(e Execer, path string) (*sql.Result, error) { realpath, err := filepath.Abs(path) if err != nil { return nil, err } contents, err := os.ReadFile(realpath) if err != nil { return nil, err } res, err := e.Exec(string(contents)) return &res, err } // MustExec execs the query using e and panics if there was an error. 
// Any placeholder parameters are replaced with supplied args. func MustExec(e Execer, query string, args ...interface{}) sql.Result { res, err := e.Exec(query, args...) if err != nil { panic(err) } return res } // SliceScan using this Rows. func (r *Row) SliceScan() ([]interface{}, error) { return SliceScan(r) } // MapScan using this Rows. func (r *Row) MapScan(dest map[string]interface{}) error { return MapScan(r, dest) } func (r *Row) scanAny(dest interface{}, structOnly bool) error { if r.err != nil { return r.err } if r.rows == nil { r.err = sql.ErrNoRows return r.err } defer r.rows.Close() v := reflect.ValueOf(dest) if v.Kind() != reflect.Ptr { return errors.New("must pass a pointer, not a value, to StructScan destination") } if v.IsNil() { return errors.New("nil pointer passed to StructScan destination") } base := reflectx.Deref(v.Type()) scannable := isScannable(base) if structOnly && scannable { return structOnlyError(base) } columns, err := r.Columns() if err != nil { return err } if scannable && len(columns) > 1 { return fmt.Errorf("scannable dest type %s with >1 columns (%d) in result", base.Kind(), len(columns)) } if scannable { return r.Scan(dest) } m := r.Mapper fields := m.TraversalsByName(v.Type(), columns) // if we are not unsafe and are missing fields, return an error if f, err := missingFields(fields); err != nil && !r.unsafe { return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } values := make([]interface{}, len(columns)) err = fieldsByTraversal(v, fields, values, true) if err != nil { return err } // scan into the struct field pointers and append to our results return r.Scan(values...) } // StructScan a single Row into dest. func (r *Row) StructScan(dest interface{}) error { return r.scanAny(dest, true) } // SliceScan a row, returning a []interface{} with values similar to MapScan. // This function is primarily intended for use where the number of columns // is not known. 
Because you can pass an []interface{} directly to Scan, // it's recommended that you do that as it will not have to allocate new // slices per row. func SliceScan(r ColScanner) ([]interface{}, error) { // ignore r.started, since we needn't use reflect for anything. columns, err := r.Columns() if err != nil { return []interface{}{}, err } values := make([]interface{}, len(columns)) for i := range values { values[i] = new(interface{}) } err = r.Scan(values...) if err != nil { return values, err } for i := range columns { values[i] = *(values[i].(*interface{})) } return values, r.Err() } // MapScan scans a single Row into the dest map[string]interface{}. // Use this to get results for SQL that might not be under your control // (for instance, if you're building an interface for an SQL server that // executes SQL from input). Please do not use this as a primary interface! // This will modify the map sent to it in place, so reuse the same map with // care. Columns which occur more than once in the result will overwrite // each other! func MapScan(r ColScanner, dest map[string]interface{}) error { // ignore r.started, since we needn't use reflect for anything. columns, err := r.Columns() if err != nil { return err } values := make([]interface{}, len(columns)) for i := range values { values[i] = new(interface{}) } err = r.Scan(values...) 
if err != nil { return err } for i, column := range columns { dest[column] = *(values[i].(*interface{})) } return r.Err() } type rowsi interface { Close() error Columns() ([]string, error) Err() error Next() bool Scan(...interface{}) error } // structOnlyError returns an error appropriate for type when a non-scannable // struct is expected but something else is given func structOnlyError(t reflect.Type) error { isStruct := t.Kind() == reflect.Struct isScanner := reflect.PointerTo(t).Implements(_scannerInterface) if !isStruct { return fmt.Errorf("expected %s but got %s", reflect.Struct, t.Kind()) } if isScanner { return fmt.Errorf("structscan expects a struct dest but the provided struct type %s implements scanner", t.Name()) } return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name()) } // scanAll scans all rows into a destination, which must be a slice of any // type. It resets the slice length to zero before appending each element to // the slice. If the destination slice type is a Struct, then StructScan will // be used on each row. If the destination is some other kind of base type, // then each row must only have one column which can scan into that type. This // allows you to do something like: // // rows, _ := db.Query("select id from people;") // var ids []int // scanAll(rows, &ids, false) // // and ids will be a list of the id results. I realize that this is a desirable // interface to expose to users, but for now it will only be exposed via changes // to `Get` and `Select`. The reason that this has been implemented like this is // this is the only way to not duplicate reflect work in the new API while // maintaining backwards compatibility. 
func scanAll(rows rowsi, dest interface{}, structOnly bool) error {
	var v, vp reflect.Value

	value := reflect.ValueOf(dest)

	// Like json.Unmarshal, reject non-pointer and nil-pointer destinations.
	if value.Kind() != reflect.Ptr {
		return errors.New("must pass a pointer, not a value, to StructScan destination")
	}
	if value.IsNil() {
		return errors.New("nil pointer passed to StructScan destination")
	}
	direct := reflect.Indirect(value)

	slice, err := baseType(value.Type(), reflect.Slice)
	if err != nil {
		return err
	}
	// Truncate the caller's slice (keeping its capacity). Note this happens
	// before the structOnly/column checks below, so dest is emptied even if
	// one of them returns an error.
	direct.SetLen(0)

	// isPtr: elements are pointers, so append the *T from reflect.New rather
	// than the T it points to.
	isPtr := slice.Elem().Kind() == reflect.Ptr
	base := reflectx.Deref(slice.Elem())
	scannable := isScannable(base)

	if structOnly && scannable {
		return structOnlyError(base)
	}

	columns, err := rows.Columns()
	if err != nil {
		return err
	}

	// if it's a base type make sure it only has 1 column;  if not return an error
	if scannable && len(columns) > 1 {
		return fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(columns))
	}

	if !scannable {
		var values []interface{}
		var m *reflectx.Mapper

		// Prefer the mapper carried by *sqlx.Rows; fall back to the default.
		switch rows := rows.(type) {
		case *Rows:
			m = rows.Mapper
		default:
			m = mapper()
		}

		// Column-to-field traversals are computed once and reused for every row.
		fields := m.TraversalsByName(base, columns)
		// if we are not unsafe and are missing fields, return an error
		if f, err := missingFields(fields); err != nil && !isUnsafe(rows) {
			return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
		}
		// values is allocated once; fieldsByTraversal overwrites every slot
		// for each row.
		values = make([]interface{}, len(columns))

		for rows.Next() {
			// create a new struct type (which returns PtrTo) and indirect it
			vp = reflect.New(base)
			v = reflect.Indirect(vp)

			err = fieldsByTraversal(v, fields, values, true)
			if err != nil {
				return err
			}

			// scan into the struct field pointers and append to our results
			err = rows.Scan(values...)
			if err != nil {
				return err
			}

			if isPtr {
				direct.Set(reflect.Append(direct, vp))
			} else {
				direct.Set(reflect.Append(direct, v))
			}
		}
	} else {
		// Scannable element type: scan the single column into a fresh element.
		for rows.Next() {
			vp = reflect.New(base)
			err = rows.Scan(vp.Interface())
			if err != nil {
				return err
			}
			// append
			if isPtr {
				direct.Set(reflect.Append(direct, vp))
			} else {
				direct.Set(reflect.Append(direct, reflect.Indirect(vp)))
			}
		}
	}

	// Surface any iteration error that ended the Next loop.
	return rows.Err()
}

// FIXME: StructScan was the very first bit of API in sqlx, and now unfortunately
// it doesn't really feel like it's named properly.  There is an incongruency
// between this and the way that StructScan (which might better be ScanStruct
// anyway) works on a rows object.

// StructScan all rows from an sql.Rows or an sqlx.Rows into the dest slice.
// StructScan will scan in the entire rows result, so if you do not want to
// allocate structs for the entire result, use Queryx and see sqlx.Rows.StructScan.
// If rows is sqlx.Rows, it will use its mapper, otherwise it will use the default.
// The slice element type must be a non-scannable struct (or pointer to one);
// otherwise an error from structOnlyError is returned.
func StructScan(rows rowsi, dest interface{}) error {
	return scanAll(rows, dest, true)
}

// reflect helpers

// baseType passes t through reflectx.Deref and returns the result if its kind
// is expected, or an "expected X but got Y" error otherwise.
func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
	t = reflectx.Deref(t)
	if t.Kind() != expected {
		return nil, fmt.Errorf("expected %s but got %s", expected, t.Kind())
	}
	return t, nil
}

// fieldsByTraversal fills values with fields of v located by traversals: entry
// i is set from the field reached by traversals[i]. If ptrs is true, field
// addresses are stored instead of field values. An empty traversal (a column
// with no matching field) gets a fresh *interface{} so Scan has somewhere to
// write. This exists instead of using FieldsByName to save allocations and map
// lookups when iterating over many rows. Because of the necessity of
// requesting ptrs or values, it's considered a bit too specialized for
// inclusion in reflectx itself.
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return errors.New("argument not a struct") } for i, traversal := range traversals { if len(traversal) == 0 { values[i] = new(interface{}) continue } f := reflectx.FieldByIndexes(v, traversal) if ptrs { values[i] = f.Addr().Interface() } else { values[i] = f.Interface() } } return nil } func missingFields(transversals [][]int) (field int, err error) { for i, t := range transversals { if len(t) == 0 { return i, errors.New("missing field") } } return 0, nil }
package sqlx import ( "context" "database/sql" "fmt" "os" "path/filepath" "reflect" ) // ConnectContext to a database and verify with a ping. func ConnectContext(ctx context.Context, driverName, dataSourceName string) (*DB, error) { db, err := Open(driverName, dataSourceName) if err != nil { return db, err } err = db.PingContext(ctx) return db, err } // QueryerContext is an interface used by GetContext and SelectContext type QueryerContext interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row } // PreparerContext is an interface used by PreparexContext. type PreparerContext interface { PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) } // ExecerContext is an interface used by MustExecContext and LoadFileContext type ExecerContext interface { ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) } // ExtContext is a union interface which can bind, query, and exec, with Context // used by NamedQueryContext and NamedExecContext. type ExtContext interface { binder QueryerContext ExecerContext } // SelectContext executes a query using the provided Queryer, and StructScans // each row into dest, which must be a slice. If the slice elements are // scannable, then the result set must have only one column. Otherwise, // StructScan is used. The *sql.Rows are closed automatically. // Any placeholder parameters are replaced with supplied args. func SelectContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { rows, err := q.QueryxContext(ctx, query, args...) if err != nil { return err } // if something happens here, we want to make sure the rows are Closed defer rows.Close() return scanAll(rows, dest, false) } // PreparexContext prepares a statement. 
// // The provided context is used for the preparation of the statement, not for // the execution of the statement. func PreparexContext(ctx context.Context, p PreparerContext, query string) (*Stmt, error) { s, err := p.PrepareContext(ctx, query) if err != nil { return nil, err } return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err } // GetContext does a QueryRow using the provided Queryer, and scans the // resulting row to dest. If dest is scannable, the result must only have one // column. Otherwise, StructScan is used. Get will return sql.ErrNoRows like // row.Scan would. Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. func GetContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { r := q.QueryRowxContext(ctx, query, args...) return r.scanAny(dest, false) } // LoadFileContext exec's every statement in a file (as a single call to Exec). // LoadFileContext may return a nil *sql.Result if errors are encountered // locating or reading the file at path. LoadFile reads the entire file into // memory, so it is not suitable for loading large data dumps, but can be useful // for initializing schemas or loading indexes. // // FIXME: this does not really work with multi-statement files for mattn/go-sqlite3 // or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting // this by requiring something with DriverName() and then attempting to split the // queries will be difficult to get right, and its current driver-specific behavior // is deemed at least not complex in its incorrectness. 
func LoadFileContext(ctx context.Context, e ExecerContext, path string) (*sql.Result, error) { realpath, err := filepath.Abs(path) if err != nil { return nil, err } contents, err := os.ReadFile(realpath) if err != nil { return nil, err } res, err := e.ExecContext(ctx, string(contents)) return &res, err } // MustExecContext execs the query using e and panics if there was an error. // Any placeholder parameters are replaced with supplied args. func MustExecContext(ctx context.Context, e ExecerContext, query string, args ...interface{}) sql.Result { res, err := e.ExecContext(ctx, query, args...) if err != nil { panic(err) } return res } // PrepareNamedContext returns an sqlx.NamedStmt func (db *DB) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) { return prepareNamedContext(ctx, db, query) } // NamedQueryContext using this DB. // Any named placeholder parameters are replaced with fields from arg. func (db *DB) NamedQueryContext(ctx context.Context, query string, arg interface{}) (*Rows, error) { return NamedQueryContext(ctx, db, query, arg) } // NamedExecContext using this DB. // Any named placeholder parameters are replaced with fields from arg. func (db *DB) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { return NamedExecContext(ctx, db, query, arg) } // SelectContext using this DB. // Any placeholder parameters are replaced with supplied args. func (db *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { return SelectContext(ctx, db, dest, query, args...) } // GetContext using this DB. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. func (db *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { return GetContext(ctx, db, dest, query, args...) } // PreparexContext returns an sqlx.Stmt instead of a sql.Stmt. 
// // The provided context is used for the preparation of the statement, not for // the execution of the statement. func (db *DB) PreparexContext(ctx context.Context, query string) (*Stmt, error) { return PreparexContext(ctx, db, query) } // QueryxContext queries the database and returns an *sqlx.Rows. // Any placeholder parameters are replaced with supplied args. func (db *DB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { r, err := db.DB.QueryContext(ctx, query, args...) if err != nil { return nil, err } return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err } // QueryRowxContext queries the database and returns an *sqlx.Row. // Any placeholder parameters are replaced with supplied args. func (db *DB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { rows, err := db.DB.QueryContext(ctx, query, args...) return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} } // MustBeginTx starts a transaction, and panics on error. Returns an *sqlx.Tx instead // of an *sql.Tx. // // The provided context is used until the transaction is committed or rolled // back. If the context is canceled, the sql package will roll back the // transaction. Tx.Commit will return an error if the context provided to // MustBeginContext is canceled. func (db *DB) MustBeginTx(ctx context.Context, opts *sql.TxOptions) *Tx { tx, err := db.BeginTxx(ctx, opts) if err != nil { panic(err) } return tx } // MustExecContext (panic) runs MustExec using this database. // Any placeholder parameters are replaced with supplied args. func (db *DB) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { return MustExecContext(ctx, db, query, args...) } // BeginTxx begins a transaction and returns an *sqlx.Tx instead of an // *sql.Tx. // // The provided context is used until the transaction is committed or rolled // back. 
If the context is canceled, the sql package will roll back the // transaction. Tx.Commit will return an error if the context provided to // BeginxContext is canceled. func (db *DB) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { tx, err := db.DB.BeginTx(ctx, opts) if err != nil { return nil, err } return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err } // Connx returns an *sqlx.Conn instead of an *sql.Conn. func (db *DB) Connx(ctx context.Context) (*Conn, error) { conn, err := db.DB.Conn(ctx) if err != nil { return nil, err } return &Conn{Conn: conn, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, nil } // BeginTxx begins a transaction and returns an *sqlx.Tx instead of an // *sql.Tx. // // The provided context is used until the transaction is committed or rolled // back. If the context is canceled, the sql package will roll back the // transaction. Tx.Commit will return an error if the context provided to // BeginxContext is canceled. func (c *Conn) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { tx, err := c.Conn.BeginTx(ctx, opts) if err != nil { return nil, err } return &Tx{Tx: tx, driverName: c.driverName, unsafe: c.unsafe, Mapper: c.Mapper}, err } // SelectContext using this Conn. // Any placeholder parameters are replaced with supplied args. func (c *Conn) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { return SelectContext(ctx, c, dest, query, args...) } // GetContext using this Conn. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. func (c *Conn) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { return GetContext(ctx, c, dest, query, args...) } // PreparexContext returns an sqlx.Stmt instead of a sql.Stmt. 
// // The provided context is used for the preparation of the statement, not for // the execution of the statement. func (c *Conn) PreparexContext(ctx context.Context, query string) (*Stmt, error) { return PreparexContext(ctx, c, query) } // QueryxContext queries the database and returns an *sqlx.Rows. // Any placeholder parameters are replaced with supplied args. func (c *Conn) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { r, err := c.Conn.QueryContext(ctx, query, args...) if err != nil { return nil, err } return &Rows{Rows: r, unsafe: c.unsafe, Mapper: c.Mapper}, err } // QueryRowxContext queries the database and returns an *sqlx.Row. // Any placeholder parameters are replaced with supplied args. func (c *Conn) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { rows, err := c.Conn.QueryContext(ctx, query, args...) return &Row{rows: rows, err: err, unsafe: c.unsafe, Mapper: c.Mapper} } // Rebind a query within a Conn's bindvar type. func (c *Conn) Rebind(query string) string { return Rebind(BindType(c.driverName), query) } // StmtxContext returns a version of the prepared statement which runs within a // transaction. Provided stmt can be either *sql.Stmt or *sqlx.Stmt. func (tx *Tx) StmtxContext(ctx context.Context, stmt interface{}) *Stmt { var s *sql.Stmt switch v := stmt.(type) { case Stmt: s = v.Stmt case *Stmt: s = v.Stmt case *sql.Stmt: s = v default: panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type())) } return &Stmt{Stmt: tx.StmtContext(ctx, s), Mapper: tx.Mapper} } // NamedStmtContext returns a version of the prepared statement which runs // within a transaction. func (tx *Tx) NamedStmtContext(ctx context.Context, stmt *NamedStmt) *NamedStmt { return &NamedStmt{ QueryString: stmt.QueryString, Params: stmt.Params, Stmt: tx.StmtxContext(ctx, stmt.Stmt), } } // PreparexContext returns an sqlx.Stmt instead of a sql.Stmt. 
// // The provided context is used for the preparation of the statement, not for // the execution of the statement. func (tx *Tx) PreparexContext(ctx context.Context, query string) (*Stmt, error) { return PreparexContext(ctx, tx, query) } // PrepareNamedContext returns an sqlx.NamedStmt func (tx *Tx) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) { return prepareNamedContext(ctx, tx, query) } // MustExecContext runs MustExecContext within a transaction. // Any placeholder parameters are replaced with supplied args. func (tx *Tx) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { return MustExecContext(ctx, tx, query, args...) } // QueryxContext within a transaction and context. // Any placeholder parameters are replaced with supplied args. func (tx *Tx) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { r, err := tx.Tx.QueryContext(ctx, query, args...) if err != nil { return nil, err } return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err } // SelectContext within a transaction and context. // Any placeholder parameters are replaced with supplied args. func (tx *Tx) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { return SelectContext(ctx, tx, dest, query, args...) } // GetContext within a transaction and context. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. func (tx *Tx) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { return GetContext(ctx, tx, dest, query, args...) } // QueryRowxContext within a transaction and context. // Any placeholder parameters are replaced with supplied args. func (tx *Tx) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { rows, err := tx.Tx.QueryContext(ctx, query, args...) 
return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} } // NamedExecContext using this Tx. // Any named placeholder parameters are replaced with fields from arg. func (tx *Tx) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { return NamedExecContext(ctx, tx, query, arg) } // SelectContext using the prepared statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) SelectContext(ctx context.Context, dest interface{}, args ...interface{}) error { return SelectContext(ctx, &qStmt{s}, dest, "", args...) } // GetContext using the prepared statement. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. func (s *Stmt) GetContext(ctx context.Context, dest interface{}, args ...interface{}) error { return GetContext(ctx, &qStmt{s}, dest, "", args...) } // MustExecContext (panic) using this statement. Note that the query portion of // the error output will be blank, as Stmt does not expose its query. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) MustExecContext(ctx context.Context, args ...interface{}) sql.Result { return MustExecContext(ctx, &qStmt{s}, "", args...) } // QueryRowxContext using this statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) QueryRowxContext(ctx context.Context, args ...interface{}) *Row { qs := &qStmt{s} return qs.QueryRowxContext(ctx, "", args...) } // QueryxContext using this statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) QueryxContext(ctx context.Context, args ...interface{}) (*Rows, error) { qs := &qStmt{s} return qs.QueryxContext(ctx, "", args...) } func (q *qStmt) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { return q.Stmt.QueryContext(ctx, args...) 
} func (q *qStmt) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { r, err := q.Stmt.QueryContext(ctx, args...) if err != nil { return nil, err } return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err } func (q *qStmt) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { rows, err := q.Stmt.QueryContext(ctx, args...) return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper} } func (q *qStmt) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { return q.Stmt.ExecContext(ctx, args...) }
package types import ( "bytes" "compress/gzip" "database/sql/driver" "encoding/json" "errors" "io" ) // GzippedText is a []byte which transparently gzips data being submitted to // a database and ungzips data being Scanned from a database. type GzippedText []byte // Value implements the driver.Valuer interface, gzipping the raw value of // this GzippedText. func (g GzippedText) Value() (driver.Value, error) { b := make([]byte, 0, len(g)) buf := bytes.NewBuffer(b) w := gzip.NewWriter(buf) w.Write(g) w.Close() return buf.Bytes(), nil } // Scan implements the sql.Scanner interface, ungzipping the value coming off // the wire and storing the raw result in the GzippedText. func (g *GzippedText) Scan(src interface{}) error { var source []byte switch src := src.(type) { case string: source = []byte(src) case []byte: source = src default: //lint:ignore ST1005 changing this could break consumers of this package return errors.New("Incompatible type for GzippedText") } reader, err := gzip.NewReader(bytes.NewReader(source)) if err != nil { return err } defer reader.Close() b, err := io.ReadAll(reader) if err != nil { return err } *g = GzippedText(b) return nil } // JSONText is a json.RawMessage, which is a []byte underneath. // Value() validates the json format in the source, and returns an error if // the json is not valid. Scan does no validation. JSONText additionally // implements `Unmarshal`, which unmarshals the json within to an interface{} type JSONText json.RawMessage var emptyJSON = JSONText("{}") // MarshalJSON returns the *j as the JSON encoding of j. func (j JSONText) MarshalJSON() ([]byte, error) { if len(j) == 0 { return emptyJSON, nil } return j, nil } // UnmarshalJSON sets *j to a copy of data func (j *JSONText) UnmarshalJSON(data []byte) error { if j == nil { return errors.New("JSONText: UnmarshalJSON on nil pointer") } *j = append((*j)[0:0], data...) return nil } // Value returns j as a value. This does a validating unmarshal into another // RawMessage. 
If j is invalid json, it returns an error. func (j JSONText) Value() (driver.Value, error) { var m json.RawMessage var err = j.Unmarshal(&m) if err != nil { return []byte{}, err } return []byte(j), nil } // Scan stores the src in *j. No validation is done. func (j *JSONText) Scan(src interface{}) error { var source []byte switch t := src.(type) { case string: source = []byte(t) case []byte: if len(t) == 0 { source = emptyJSON } else { source = t } case nil: *j = emptyJSON default: //lint:ignore ST1005 changing this could break consumers of this package return errors.New("Incompatible type for JSONText") } *j = append((*j)[0:0], source...) return nil } // Unmarshal unmarshal's the json in j to v, as in json.Unmarshal. func (j *JSONText) Unmarshal(v interface{}) error { if len(*j) == 0 { *j = emptyJSON } return json.Unmarshal([]byte(*j), v) } // String supports pretty printing for JSONText types. func (j JSONText) String() string { return string(j) } // NullJSONText represents a JSONText that may be null. // NullJSONText implements the scanner interface so // it can be used as a scan destination, similar to NullString. type NullJSONText struct { JSONText Valid bool // Valid is true if JSONText is not NULL } // Scan implements the Scanner interface. func (n *NullJSONText) Scan(value interface{}) error { if value == nil { n.JSONText, n.Valid = emptyJSON, false return nil } n.Valid = true return n.JSONText.Scan(value) } // Value implements the driver Valuer interface. func (n NullJSONText) Value() (driver.Value, error) { if !n.Valid { return nil, nil } return n.JSONText.Value() } // BitBool is an implementation of a bool for the MySQL type BIT(1). // This type allows you to avoid wasting an entire byte for MySQL's boolean type TINYINT. type BitBool bool // Value implements the driver.Valuer interface, // and turns the BitBool into a bitfield (BIT(1)) for MySQL storage. 
func (b BitBool) Value() (driver.Value, error) { if b { return []byte{1}, nil } return []byte{0}, nil } // Scan implements the sql.Scanner interface, // and turns the bitfield incoming from MySQL into a BitBool func (b *BitBool) Scan(src interface{}) error { v, ok := src.([]byte) if !ok { return errors.New("bad []byte type assertion") } *b = v[0] == 1 return nil }