package squirrel
import (
"bytes"
"errors"
"fmt"
"reflect"
"time"
"github.com/lann/builder"
)
// init registers CaseBuilder with the lann/builder package, pairing it with
// caseData; CaseBuilder.ToSql relies on builder.GetStruct returning a caseData.
func init() {
	builder.Register(CaseBuilder{}, caseData{})
}
// sqlizerBuffer accumulates the SQL text and bound arguments of several
// Sqlizers written one after another. The first error encountered is latched
// in err: once it is set, further WriteSql calls are no-ops, so callers only
// need to check for failure once, at the end, via ToSql.
type sqlizerBuffer struct {
	bytes.Buffer
	args []any
	err  error
}

// WriteSql renders item via nestedToSql, appends its SQL followed by a single
// space, and collects its arguments. It does nothing if an earlier call failed.
func (b *sqlizerBuffer) WriteSql(item Sqlizer) {
	if b.err != nil {
		return
	}
	fragment, fragmentArgs, err := nestedToSql(item)
	if err != nil {
		b.err = err
		return
	}
	_, _ = b.WriteString(fragment + " ")
	b.args = append(b.args, fragmentArgs...)
}

// ToSql returns the accumulated SQL, the collected arguments and the first
// error recorded by WriteSql, if any.
func (b *sqlizerBuffer) ToSql() (string, []any, error) {
	text := b.String()
	return text, b.args, b.err
}
// whenPart is a helper structure describing one SQL "WHEN ... THEN ..."
// branch. newWhenPart fills exactly one THEN representation: then (rendered
// inline), thenValue (bound as a "?" argument) or nullThen (the "?" argument
// is bound to NULL).
type whenPart struct {
	when, then Sqlizer
	thenValue  any
	nullThen   bool
}
// newWhenPart builds a WHEN/THEN branch. The THEN result is interpreted as:
//   - a Sqlizer: rendered inline;
//   - nil: THEN bound to NULL;
//   - a value whose Go type sqlTypeNameHelper can name: wrapped in
//     CAST(? AS <type>) so the database can infer the result type;
//   - any other value: bound as a plain "?" argument.
func newWhenPart(when any, then any) whenPart {
	part := whenPart{when: newPart(when)}

	if _, isSqlizer := then.(Sqlizer); isSqlizer {
		part.then = newPart(then)
		return part
	}
	if then == nil {
		part.nullThen = true
		return part
	}

	sqlName, err := sqlTypeNameHelper(reflect.TypeOf(then))
	if err != nil {
		part.thenValue = then
		return part
	}
	part.then = newPart(Expr(fmt.Sprintf("CAST(? AS %s)", sqlName), then))
	return part
}
// sqlTypeNameHelper maps a Go type to the PostgreSQL-style type name used when
// a CASE THEN value is wrapped in CAST(? AS <type>).
//
// Slices and arrays map to "<elem>[]", except byte slices. database drivers
// bind []byte as a single binary value, so byte slices map to "bytea".
// An error is returned for types without a mapping; newWhenPart then binds the
// value without a cast.
func sqlTypeNameHelper(t reflect.Type) (string, error) {
	switch t.Kind() { //nolint:exhaustive
	case reflect.Bool:
		return "boolean", nil
	case reflect.Int64, reflect.Uint64, reflect.Int, reflect.Uint:
		return "bigint", nil
	case reflect.Int32, reflect.Uint32:
		return "integer", nil
	case reflect.Int16, reflect.Uint16, reflect.Int8, reflect.Uint8:
		return "smallint", nil
	case reflect.Float32, reflect.Float64:
		return "double precision", nil
	case reflect.String:
		return "text", nil
	case reflect.Struct:
		if t == reflect.TypeOf(time.Time{}) {
			return "timestamp with time zone", nil
		}
	case reflect.Slice, reflect.Array:
		// A byte slice is one binary value (database/sql/driver treats []byte
		// as a driver.Value), not an array of small integers.
		if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 {
			return "bytea", nil
		}
		sqlType, err := sqlTypeNameHelper(t.Elem())
		if err != nil {
			return "", err
		}
		return sqlType + "[]", nil
	}
	// Use t.String(), not t.Name(). Name is empty for unnamed types such as
	// map[string]int or *int, which made the message "unsupported type ".
	return "", fmt.Errorf("unsupported type %s", t.String())
}
// caseData holds all the data required to build a CASE SQL construct.
// CaseBuilder keeps its state in this struct (CaseBuilder.ToSql retrieves it
// with builder.GetStruct); the field names are the keys passed to
// builder.Set / builder.Append.
type caseData struct {
	// What is the optional operand in "CASE <What> WHEN ...".
	What Sqlizer
	// WhenParts are the WHEN/THEN branches, rendered in order.
	WhenParts []whenPart
	// Else is the ELSE result when it was given as a Sqlizer.
	Else Sqlizer
	// ElseValue is an ELSE result bound as a "?" argument.
	ElseValue any
	// ElseNull is set by Else(nil). ToSql then emits an ELSE placeholder bound
	// to ElseValue (normally nil).
	ElseNull bool
}
// ToSql implements Sqlizer. It renders
//
//	CASE [what] WHEN <when> THEN <then> ... [ELSE <else>] END
//
// THEN/ELSE results that are not Sqlizers are bound as "?" placeholders.
// It returns an error when there is no WHEN branch, or when a branch has no
// THEN. Errors from nested Sqlizers are latched by the buffer and returned at
// the end.
// NOTE(review): the "lease" typo in the first error message is kept
// unchanged; callers or tests may match on the exact text.
func (d *caseData) ToSql() (sqlStr string, args []any, err error) {
	if len(d.WhenParts) == 0 {
		return "", nil, errors.New("case expression must contain at lease one WHEN clause")
	}

	var buf sqlizerBuffer
	buf.WriteString("CASE ")
	if d.What != nil {
		buf.WriteSql(d.What)
	}

	for _, part := range d.WhenParts {
		buf.WriteString("WHEN ")
		buf.WriteSql(part.when)

		hasThen := part.then != nil || part.thenValue != nil || part.nullThen
		if !hasThen {
			return "", nil, errors.New("When clause must have Then part")
		}

		buf.WriteString("THEN ")
		switch {
		case part.then != nil:
			buf.WriteSql(part.then)
		default:
			buf.WriteString(Placeholders(1) + " ")
			buf.args = append(buf.args, part.thenValue)
		}
	}

	switch {
	case d.Else != nil:
		buf.WriteString("ELSE ")
		buf.WriteSql(d.Else)
	case d.ElseValue != nil || d.ElseNull:
		buf.WriteString("ELSE ")
		buf.WriteString(Placeholders(1) + " ")
		buf.args = append(buf.args, d.ElseValue)
	}

	buf.WriteString("END")
	return buf.ToSql()
}
// CaseBuilder builds SQL CASE construct which could be used as parts of queries.
type CaseBuilder builder.Builder

// ToSql builds the query into a SQL string and bound args.
func (b CaseBuilder) ToSql() (string, []any, error) {
	d := builder.GetStruct(b).(caseData)
	return d.ToSql()
}

// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b CaseBuilder) MustSql() (string, []any) {
	query, args, err := b.ToSql()
	if err == nil {
		return query, args
	}
	panic(err)
}
// what sets the optional operand of the CASE construct: "CASE [value] ...".
func (b CaseBuilder) what(e any) CaseBuilder {
	operand := newPart(e)
	return builder.Set(b, "What", operand).(CaseBuilder)
}

// When appends a "WHEN ... THEN ..." branch to the CASE construct. Branches
// are rendered in the order they are added; see newWhenPart for how then is
// interpreted (Sqlizer, nil, or a plain value).
func (b CaseBuilder) When(when any, then any) CaseBuilder {
	// TODO: performance hint: replace slice of WhenPart with just slice of parts
	// where even indices of the slice belong to "when"s and odd indices belong to "then"s
	branch := newWhenPart(when, then)
	return builder.Append(b, "WhenParts", branch).(CaseBuilder)
}
// Else sets the optional "ELSE ..." part of the CASE construct. Each call
// replaces whatever an earlier Else call configured:
//   - a Sqlizer is rendered inline;
//   - nil renders an ELSE placeholder bound to NULL;
//   - any other value is bound as a "?" argument.
func (b CaseBuilder) Else(e any) CaseBuilder {
	// Clear all three ELSE representations first. caseData.ToSql prefers Else
	// over ElseValue/ElseNull, so without this reset a previous setting could
	// win. For example, Else(Expr("x")).Else(1) kept rendering x, and
	// Else(1).Else(nil) kept binding 1 instead of NULL.
	b = builder.Delete(b, "Else").(CaseBuilder)
	b = builder.Delete(b, "ElseValue").(CaseBuilder)
	b = builder.Delete(b, "ElseNull").(CaseBuilder)

	switch e.(type) {
	case Sqlizer:
		return builder.Set(b, "Else", newPart(e)).(CaseBuilder)
	default:
		if e == nil {
			return builder.Set(b, "ElseNull", true).(CaseBuilder)
		}
		return builder.Set(b, "ElseValue", e).(CaseBuilder)
	}
}
package squirrel
import (
"bytes"
"fmt"
"github.com/lann/builder"
)
// Common Table Expressions helper
// e.g.
//
//	WITH cte AS (
//	  ...
//	), cte_2 AS (
//	  ...
//	)
//	SELECT ... FROM cte ... cte_2;
type commonTableExpressionsData struct {
	// PlaceholderFormat is applied once, to the complete rendered statement.
	PlaceholderFormat PlaceholderFormat
	// Recursive adds the RECURSIVE keyword after WITH.
	Recursive bool
	// CurrentCteName is the name set by the latest Cte call; As uses it to
	// label the next subquery.
	CurrentCteName string
	// Ctes are the "<name> AS (<subquery>)" entries, joined with ", ".
	Ctes []Sqlizer
	// Statement is the final statement that follows the CTE list.
	Statement Sqlizer
}
// toSql renders "WITH [RECURSIVE] <cte>, <cte>... <statement>" and then
// applies d.PlaceholderFormat to the whole text. At least one CTE and a
// final statement are required.
func (d *commonTableExpressionsData) toSql() (sqlStr string, args []any, err error) {
	switch {
	case len(d.Ctes) == 0:
		return "", nil, fmt.Errorf("common table expressions statements must have at least one label and subquery")
	case d.Statement == nil:
		return "", nil, fmt.Errorf("common table expressions must one of the following final statement: (select, insert, replace, update, delete)")
	}

	keyword := "WITH "
	if d.Recursive {
		keyword = "WITH RECURSIVE "
	}
	buf := new(bytes.Buffer)
	_, _ = buf.WriteString(keyword)

	if args, err = appendToSql(d.Ctes, buf, ", ", args); err != nil {
		return "", nil, err
	}
	_ = buf.WriteByte(' ')
	if args, err = appendToSql([]Sqlizer{d.Statement}, buf, "", args); err != nil {
		return "", nil, err
	}

	sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(buf.String())
	return sqlStr, args, err
}

// ToSql implements Sqlizer by delegating to toSql.
func (d *commonTableExpressionsData) ToSql() (sql string, args []any, err error) {
	sql, args, err = d.toSql()
	return sql, args, err
}
// Builder
// CommonTableExpressionsBuilder builds CTE (Common Table Expressions) SQL statements.
// It is immutable: every method returns a modified copy.
type CommonTableExpressionsBuilder builder.Builder

// init registers CommonTableExpressionsBuilder with lann/builder, pairing it
// with commonTableExpressionsData (the struct its ToSql retrieves via
// builder.GetStruct).
func init() {
	builder.Register(CommonTableExpressionsBuilder{}, commonTableExpressionsData{})
}
// Format methods
// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the
// query. The format is applied once, to the complete WITH statement.
func (b CommonTableExpressionsBuilder) PlaceholderFormat(f PlaceholderFormat) CommonTableExpressionsBuilder {
	updated := builder.Set(b, "PlaceholderFormat", f)
	return updated.(CommonTableExpressionsBuilder)
}

// SQL methods

// ToSql builds the query into a SQL string and bound args.
func (b CommonTableExpressionsBuilder) ToSql() (string, []any, error) {
	d := builder.GetStruct(b).(commonTableExpressionsData)
	return d.ToSql()
}

// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b CommonTableExpressionsBuilder) MustSql() (string, []any) {
	query, args, err := b.ToSql()
	if err == nil {
		return query, args
	}
	panic(err)
}
// Recursive toggles the RECURSIVE keyword: "WITH RECURSIVE ...".
func (b CommonTableExpressionsBuilder) Recursive(recursive bool) CommonTableExpressionsBuilder {
	updated := builder.Set(b, "Recursive", recursive)
	return updated.(CommonTableExpressionsBuilder)
}

// Cte starts a new cte: it names the CTE whose subquery the next As call supplies.
func (b CommonTableExpressionsBuilder) Cte(cte string) CommonTableExpressionsBuilder {
	updated := builder.Set(b, "CurrentCteName", cte)
	return updated.(CommonTableExpressionsBuilder)
}

// As sets the expression for the Cte: it appends "<name> AS (<as>)", using
// the name from the most recent Cte call (empty if Cte was never called).
func (b CommonTableExpressionsBuilder) As(as SelectBuilder) CommonTableExpressionsBuilder {
	name := builder.GetStruct(b).(commonTableExpressionsData).CurrentCteName
	entry := cteExpr{expr: as, cte: name}
	return builder.Append(b, "Ctes", entry).(CommonTableExpressionsBuilder)
}

// Select finalizes the CommonTableExpressionsBuilder with a SELECT
func (b CommonTableExpressionsBuilder) Select(statement SelectBuilder) CommonTableExpressionsBuilder {
	updated := builder.Set(b, "Statement", statement)
	return updated.(CommonTableExpressionsBuilder)
}

// Insert finalizes the CommonTableExpressionsBuilder with an INSERT
func (b CommonTableExpressionsBuilder) Insert(statement InsertBuilder) CommonTableExpressionsBuilder {
	updated := builder.Set(b, "Statement", statement)
	return updated.(CommonTableExpressionsBuilder)
}

// Replace finalizes the CommonTableExpressionsBuilder with a REPLACE. The
// statement is stored exactly like an INSERT; the REPLACE keyword comes from
// how the InsertBuilder itself was constructed.
func (b CommonTableExpressionsBuilder) Replace(statement InsertBuilder) CommonTableExpressionsBuilder {
	updated := builder.Set(b, "Statement", statement)
	return updated.(CommonTableExpressionsBuilder)
}

// Update finalizes the CommonTableExpressionsBuilder with an UPDATE
func (b CommonTableExpressionsBuilder) Update(statement UpdateBuilder) CommonTableExpressionsBuilder {
	updated := builder.Set(b, "Statement", statement)
	return updated.(CommonTableExpressionsBuilder)
}

// Delete finalizes the CommonTableExpressionsBuilder with a DELETE
func (b CommonTableExpressionsBuilder) Delete(statement DeleteBuilder) CommonTableExpressionsBuilder {
	updated := builder.Set(b, "Statement", statement)
	return updated.(CommonTableExpressionsBuilder)
}
package squirrel
import (
"bytes"
"fmt"
"strings"
"github.com/lann/builder"
)
// deleteData holds the state of a DeleteBuilder and renders a DELETE statement.
type deleteData struct {
	// PlaceholderFormat is applied by ToSql (not by toSqlRaw).
	PlaceholderFormat PlaceholderFormat
	// Prefixes are rendered, space-separated, before "DELETE FROM".
	Prefixes []Sqlizer
	// From is the target table; it is required.
	From string
	// WhereParts are joined with " AND " after "WHERE".
	WhereParts []Sqlizer
	// OrderBys are joined with ", " after "ORDER BY".
	OrderBys []string
	// Limit and Offset are decimal strings set by DeleteBuilder.Limit/Offset;
	// empty means the clause is omitted.
	Limit  string
	Offset string
	// Suffixes are rendered, space-separated, at the end of the statement.
	Suffixes []Sqlizer
}
// toSqlRaw renders the DELETE statement with raw "?" placeholders:
//
//	[prefixes] DELETE FROM <from> [WHERE ...] [ORDER BY ...] [LIMIT n] [OFFSET n] [suffixes]
func (d *deleteData) toSqlRaw() (sqlStr string, args []any, err error) {
	if d.From == "" {
		return "", nil, fmt.Errorf("delete statements must specify a From table")
	}

	buf := new(bytes.Buffer)

	if len(d.Prefixes) > 0 {
		if args, err = appendToSql(d.Prefixes, buf, " ", args); err != nil {
			return "", nil, err
		}
		_, _ = buf.WriteString(" ")
	}

	_, _ = buf.WriteString("DELETE FROM " + d.From)

	if len(d.WhereParts) > 0 {
		_, _ = buf.WriteString(" WHERE ")
		if args, err = appendToSql(d.WhereParts, buf, " AND ", args); err != nil {
			return "", nil, err
		}
	}

	if len(d.OrderBys) > 0 {
		_, _ = buf.WriteString(" ORDER BY " + strings.Join(d.OrderBys, ", "))
	}
	if d.Limit != "" {
		_, _ = buf.WriteString(" LIMIT " + d.Limit)
	}
	if d.Offset != "" {
		_, _ = buf.WriteString(" OFFSET " + d.Offset)
	}

	if len(d.Suffixes) > 0 {
		_, _ = buf.WriteString(" ")
		if args, err = appendToSql(d.Suffixes, buf, " ", args); err != nil {
			return "", nil, err
		}
	}

	return buf.String(), args, nil
}
// ToSql renders the statement via toSqlRaw and then applies d.PlaceholderFormat.
func (d *deleteData) ToSql() (sqlStr string, args []any, err error) {
	raw, rawArgs, err := d.toSqlRaw()
	if err != nil {
		return "", nil, err
	}
	sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(raw)
	return sqlStr, rawArgs, err
}
// Builder
// DeleteBuilder builds SQL DELETE statements.
// It is immutable: every method returns a modified copy.
type DeleteBuilder builder.Builder

// init registers DeleteBuilder with lann/builder, pairing it with deleteData
// (the struct its methods retrieve via builder.GetStruct).
func init() {
	builder.Register(DeleteBuilder{}, deleteData{})
}
// Format methods
// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the
// query.
func (b DeleteBuilder) PlaceholderFormat(f PlaceholderFormat) DeleteBuilder {
	updated := builder.Set(b, "PlaceholderFormat", f)
	return updated.(DeleteBuilder)
}

// SQL methods

// ToSql builds the query into a SQL string and bound args.
func (b DeleteBuilder) ToSql() (string, []any, error) {
	d := builder.GetStruct(b).(deleteData)
	return d.ToSql()
}

// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b DeleteBuilder) MustSql() (string, []any) {
	query, args, err := b.ToSql()
	if err == nil {
		return query, args
	}
	panic(err)
}
// Prefix adds an expression to the beginning of the query
func (b DeleteBuilder) Prefix(sql string, args ...any) DeleteBuilder {
	prefix := Expr(sql, args...)
	return b.PrefixExpr(prefix)
}

// PrefixExpr adds an expression to the very beginning of the query
func (b DeleteBuilder) PrefixExpr(e Sqlizer) DeleteBuilder {
	updated := builder.Append(b, "Prefixes", e)
	return updated.(DeleteBuilder)
}

// From sets the table to be deleted from.
func (b DeleteBuilder) From(from string) DeleteBuilder {
	updated := builder.Set(b, "From", from)
	return updated.(DeleteBuilder)
}

// Where adds WHERE expressions to the query. Multiple calls are combined
// with AND.
//
// See SelectBuilder.Where for more information.
func (b DeleteBuilder) Where(pred any, args ...any) DeleteBuilder {
	part := newWherePart(pred, args...)
	return builder.Append(b, "WhereParts", part).(DeleteBuilder)
}

// OrderBy adds ORDER BY expressions to the query.
func (b DeleteBuilder) OrderBy(orderBys ...string) DeleteBuilder {
	updated := builder.Extend(b, "OrderBys", orderBys)
	return updated.(DeleteBuilder)
}

// Limit sets a LIMIT clause on the query.
func (b DeleteBuilder) Limit(limit uint64) DeleteBuilder {
	return builder.Set(b, "Limit", fmt.Sprint(limit)).(DeleteBuilder)
}

// Offset sets a OFFSET clause on the query.
func (b DeleteBuilder) Offset(offset uint64) DeleteBuilder {
	return builder.Set(b, "Offset", fmt.Sprint(offset)).(DeleteBuilder)
}

// toSqlRaw builds SQL with raw placeholders ("?") without applying PlaceholderFormat.
func (b DeleteBuilder) toSqlRaw() (string, []any, error) {
	d := builder.GetStruct(b).(deleteData)
	return d.toSqlRaw()
}

// Suffix adds an expression to the end of the query
func (b DeleteBuilder) Suffix(sql string, args ...any) DeleteBuilder {
	suffix := Expr(sql, args...)
	return b.SuffixExpr(suffix)
}

// SuffixExpr adds an expression to the end of the query
func (b DeleteBuilder) SuffixExpr(e Sqlizer) DeleteBuilder {
	updated := builder.Append(b, "Suffixes", e)
	return updated.(DeleteBuilder)
}
package squirrel
import (
"bytes"
"database/sql/driver"
"fmt"
"reflect"
"sort"
"strings"
)
const (
	// Portable true/false literals.
	// sqlTrue renders an empty Eq/NotEq map, an empty NotEq list and an empty
	// And; sqlFalse renders an empty Eq list and an empty Or.
	sqlTrue  = "(1=1)"
	sqlFalse = "(1=0)"
)
// expr is a raw SQL fragment together with its bound arguments.
type expr struct {
	sql  string
	args []any
}

// Expr builds an expression from a SQL fragment and arguments. Arguments that
// are themselves Sqlizers are expanded in place of their "?" placeholder when
// the expression is rendered.
//
// Ex:
//
//	Expr("FROM_UNIXTIME(?)", t)
func Expr(sql string, args ...any) Sqlizer {
	e := expr{sql, args}
	return e
}
// ToSql renders the expression.
//
// Fast path: when no argument is a Sqlizer or a list-like value (see
// isListType), sql and args are returned unchanged.
//
// Otherwise the SQL is scanned left to right, and each "?" consumes the next
// argument:
//   - a Sqlizer argument is rendered with nestedToSql. Its SQL replaces the
//     "?" and its args are spliced into the result.
//   - any other argument keeps its "?" and is passed through unchanged. This
//     includes list-like values, which are NOT expanded here.
//
// "??" is copied through verbatim and consumes no argument. Text after the
// last consumed placeholder is appended unchanged. Arguments left over once
// the placeholders run out are also appended unchanged. If a nested Sqlizer
// fails, scanning stops and the error is returned along with the partial SQL
// and args.
func (e expr) ToSql() (sql string, args []any, err error) {
	simple := true
	for _, arg := range e.args {
		if _, ok := arg.(Sqlizer); ok {
			simple = false
		}
		if isListType(arg) {
			simple = false
		}
	}
	if simple {
		return e.sql, e.args, nil
	}
	buf := &bytes.Buffer{}
	// ap: arguments not yet consumed; sp: SQL text not yet scanned.
	ap := e.args
	sp := e.sql
	var isql string
	var iargs []any
	for err == nil && len(ap) > 0 && len(sp) > 0 {
		i := strings.Index(sp, "?")
		if i < 0 {
			// no more placeholders
			break
		}
		if len(sp) > i+1 && sp[i+1:i+2] == "?" {
			// escaped "??"; append it and step past
			buf.WriteString(sp[:i+2])
			sp = sp[i+2:]
			continue
		}
		if as, ok := ap[0].(Sqlizer); ok {
			// sqlizer argument; expand it and append the result
			isql, iargs, err = nestedToSql(as)
			buf.WriteString(sp[:i])
			buf.WriteString(isql)
			args = append(args, iargs...)
		} else {
			// normal argument; append it and the placeholder
			buf.WriteString(sp[:i+1])
			args = append(args, ap[0])
		}
		// step past the argument and placeholder
		ap = ap[1:]
		sp = sp[i+1:]
	}
	// append the remaining sql and arguments
	buf.WriteString(sp)
	return buf.String(), append(args, ap...), err
}
// concatExpr is a sequence of raw strings and Sqlizers rendered back to back.
type concatExpr []any

// ToSql concatenates the parts in order. Strings are copied verbatim;
// Sqlizers are rendered and their args collected. Any other part type is an
// error.
func (ce concatExpr) ToSql() (sql string, args []any, err error) {
	var sb strings.Builder
	for _, part := range ce {
		switch p := part.(type) {
		case string:
			sb.WriteString(p)
		case Sqlizer:
			partSQL, partArgs, partErr := nestedToSql(p)
			if partErr != nil {
				return "", nil, partErr
			}
			sb.WriteString(partSQL)
			args = append(args, partArgs...)
		default:
			return "", nil, fmt.Errorf("%#v is not a string or Sqlizer", part)
		}
	}
	return sb.String(), args, nil
}

// ConcatExpr builds an expression by concatenating strings and other expressions.
//
// Ex:
//
//	name_expr := Expr("CONCAT(?, ' ', ?)", firstName, lastName)
//	ConcatExpr("COALESCE(full_name,", name_expr, ")")
func ConcatExpr(parts ...any) concatExpr {
	return parts
}
// aliasExpr renders an expression followed by "AS <alias>".
type aliasExpr struct {
	expr  Sqlizer
	alias string
}

// Alias allows to define alias for column in SelectBuilder. Useful when column is
// defined as complex expression like IF or CASE. It renders "(<e>) AS <a>".
// Ex:
//
//	.Column(Alias(caseStmt, "case_column"))
func Alias(e Sqlizer, a string) aliasExpr {
	return aliasExpr{expr: e, alias: a}
}

func (e aliasExpr) ToSql() (sql string, args []any, err error) {
	if sql, args, err = nestedToSql(e.expr); err != nil {
		return sql, args, err
	}
	return "(" + sql + ") AS " + e.alias, args, nil
}
// Eq is syntactic sugar for use with Where/Having/Set methods.
//
// Each entry becomes one condition, and conditions are joined with AND in
// sorted key order:
//   - nil, or a nil pointer, renders "key IS NULL";
//   - a slice or array (but not []byte, see isListType) renders
//     "key IN (?,...)", or sqlFalse when empty;
//   - a SelectBuilder renders "key IN (<subquery>)";
//   - any other value renders "key = ?".
//
// driver.Valuer values are resolved via Value() first, and non-nil pointers
// are dereferenced. An empty Eq renders sqlTrue.
type Eq map[string]any

// toSQL renders eq. With useNotOpr the negated operators (<>, NOT IN, IS NOT)
// are used, and an empty list renders sqlTrue instead of sqlFalse.
func (eq Eq) toSQL(useNotOpr bool) (sql string, args []any, err error) {
	if len(eq) == 0 {
		// Empty Eq{} evaluates to true.
		sql = sqlTrue
		return sql, args, nil
	}
	var (
		exprs       = make([]string, 0, len(eq))
		equalOpr    = "="
		inOpr       = "IN"
		nullOpr     = "IS"
		inEmptyExpr = sqlFalse
	)
	if useNotOpr {
		equalOpr = "<>"
		inOpr = "NOT IN"
		nullOpr = "IS NOT"
		inEmptyExpr = sqlTrue
	}
	sortedKeys := getSortedKeys(eq)
	for _, key := range sortedKeys {
		var expr1 string
		val := eq[key]
		if v, ok := val.(driver.Valuer); ok {
			rv := reflect.ValueOf(v)
			if rv.Kind() == reflect.Ptr && rv.IsNil() &&
				rv.Type().Elem().Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) {
				// Value is declared with a value receiver, so calling it through
				// this nil pointer would panic. Treat it as NULL, mirroring how
				// database/sql handles such arguments.
				val = nil
			} else if val, err = v.Value(); err != nil {
				return "", nil, err
			}
		}
		r := reflect.ValueOf(val)
		if r.Kind() == reflect.Ptr {
			if r.IsNil() {
				val = nil
			} else {
				val = r.Elem().Interface()
			}
		}
		if val == nil {
			expr1 = fmt.Sprintf("%s %s NULL", key, nullOpr)
		} else {
			if isListType(val) {
				valVal := reflect.ValueOf(val)
				if valVal.Len() == 0 {
					expr1 = inEmptyExpr
					if args == nil {
						args = []any{}
					}
				} else {
					for i := 0; i < valVal.Len(); i++ {
						args = append(args, valVal.Index(i).Interface())
					}
					expr1 = fmt.Sprintf("%s %s (%s)", key, inOpr, Placeholders(valVal.Len()))
				}
			} else if sb, ok := val.(SelectBuilder); ok {
				var (
					subSql  string
					subArgs []any
				)
				// Raw SQL: the outer statement applies the placeholder format once.
				subSql, subArgs, err = sb.toSqlRaw()
				if err != nil {
					return "", nil, err
				}
				expr1 = fmt.Sprintf("%s %s (%s)", key, inOpr, subSql)
				args = append(args, subArgs...)
			} else {
				expr1 = fmt.Sprintf("%s %s ?", key, equalOpr)
				args = append(args, val)
			}
		}
		exprs = append(exprs, expr1)
	}
	sql = strings.Join(exprs, " AND ")
	return sql, args, nil
}

// ToSql implements Sqlizer.
func (eq Eq) ToSql() (sql string, args []any, err error) {
	return eq.toSQL(false)
}
// NotEq is syntactic sugar for use with Where/Having/Set methods. It is the
// negation of Eq: "=" becomes "<>", "IN" becomes "NOT IN" and "IS NULL"
// becomes "IS NOT NULL".
// Ex:
//
//	.Where(NotEq{"id": 1}) == "id <> 1"
type NotEq Eq

func (neq NotEq) ToSql() (sql string, args []any, err error) {
	sql, args, err = Eq(neq).toSQL(true)
	return sql, args, err
}
// Like is syntactic sugar for use with LIKE conditions.
// Ex:
//
// .Where(Like{"name": "%irrel"})
type Like map[string]any
func (lk Like) toSql(opr string) (sql string, args []any, err error) {
exprs := make([]string, 0, len(lk))
for key, val := range lk {
var expr1 string
switch v := val.(type) {
case driver.Valuer:
if val, err = v.Value(); err != nil {
return
}
}
if val == nil {
err = fmt.Errorf("cannot use null with like operators")
return
} else {
if isListType(val) {
err = fmt.Errorf("cannot use array or slice with like operators")
return
} else {
expr1 = fmt.Sprintf("%s %s ?", key, opr)
args = append(args, val)
}
}
exprs = append(exprs, expr1)
}
sql = strings.Join(exprs, " AND ")
return
}
func (lk Like) ToSql() (sql string, args []any, err error) {
return lk.toSql("LIKE")
}
// NotLike is syntactic sugar for use with NOT LIKE conditions.
// Ex:
//
//	.Where(NotLike{"name": "%irrel"})
type NotLike Like

func (nlk NotLike) ToSql() (sql string, args []any, err error) {
	const opr = "NOT LIKE"
	return Like(nlk).toSql(opr)
}

// ILike is syntactic sugar for use with ILIKE conditions.
// Ex:
//
//	.Where(ILike{"name": "sq%"})
type ILike Like

func (ilk ILike) ToSql() (sql string, args []any, err error) {
	const opr = "ILIKE"
	return Like(ilk).toSql(opr)
}

// NotILike is syntactic sugar for use with NOT ILIKE conditions.
// Ex:
//
//	.Where(NotILike{"name": "sq%"})
type NotILike Like

func (nilk NotILike) ToSql() (sql string, args []any, err error) {
	const opr = "NOT ILIKE"
	return Like(nilk).toSql(opr)
}
// Lt is syntactic sugar for use with Where/Having/Set methods.
// Ex:
//
//	.Where(Lt{"id": 1})
type Lt map[string]any

// toSql renders every entry as "key <op> ?", joined with AND in sorted key
// order. op is "<", or ">" when opposite is set, with "=" appended when orEq
// is set. nil and list values are rejected.
func (lt Lt) toSql(opposite, orEq bool) (sql string, args []any, err error) {
	opr := "<"
	if opposite {
		opr = ">"
	}
	if orEq {
		opr += "="
	}

	keys := getSortedKeys(lt)
	exprs := make([]string, 0, len(lt))
	for _, key := range keys {
		val := lt[key]
		if valuer, ok := val.(driver.Valuer); ok {
			if val, err = valuer.Value(); err != nil {
				return "", nil, err
			}
		}
		switch {
		case val == nil:
			return "", nil, fmt.Errorf("cannot use null with less than or greater than operators")
		case isListType(val):
			return "", nil, fmt.Errorf("cannot use array or slice with less than or greater than operators")
		}
		exprs = append(exprs, key+" "+opr+" ?")
		args = append(args, val)
	}
	return strings.Join(exprs, " AND "), args, nil
}

func (lt Lt) ToSql() (sql string, args []any, err error) {
	return lt.toSql(false, false)
}

// LtOrEq is syntactic sugar for use with Where/Having/Set methods.
// Ex:
//
//	.Where(LtOrEq{"id": 1}) == "id <= 1"
type LtOrEq Lt

func (ltOrEq LtOrEq) ToSql() (sql string, args []any, err error) {
	return Lt(ltOrEq).toSql(false, true)
}

// Gt is syntactic sugar for use with Where/Having/Set methods.
// Ex:
//
//	.Where(Gt{"id": 1}) == "id > 1"
type Gt Lt

func (gt Gt) ToSql() (sql string, args []any, err error) {
	return Lt(gt).toSql(true, false)
}

// GtOrEq is syntactic sugar for use with Where/Having/Set methods.
// Ex:
//
//	.Where(GtOrEq{"id": 1}) == "id >= 1"
type GtOrEq Lt

func (gtOrEq GtOrEq) ToSql() (sql string, args []any, err error) {
	return Lt(gtOrEq).toSql(true, true)
}
// conj is a list of Sqlizers combined with a boolean operator.
type conj []Sqlizer

// join renders the non-empty parts joined by sep and wrapped in parentheses.
// An empty conj renders defaultExpr with an empty, non-nil argument slice.
// Parts that render to "" are skipped along with their args. If every part
// renders to "", the result is "" with no args.
func (c conj) join(sep, defaultExpr string) (sql string, args []any, err error) {
	if len(c) == 0 {
		return defaultExpr, []any{}, nil
	}
	parts := make([]string, 0, len(c))
	for _, s := range c {
		partSQL, partArgs, partErr := nestedToSql(s)
		if partErr != nil {
			return "", nil, partErr
		}
		if partSQL == "" {
			continue
		}
		parts = append(parts, partSQL)
		args = append(args, partArgs...)
	}
	if len(parts) > 0 {
		sql = "(" + strings.Join(parts, sep) + ")"
	}
	return sql, args, nil
}

// And conjunction Sqlizers; an empty And renders sqlTrue.
type And conj

func (a And) ToSql() (string, []any, error) {
	return conj(a).join(" AND ", sqlTrue)
}

// Or conjunction Sqlizers; an empty Or renders sqlFalse.
type Or conj

func (o Or) ToSql() (string, []any, error) {
	return conj(o).join(" OR ", sqlFalse)
}
// getSortedKeys returns the keys of exp in ascending lexical order, giving the
// map-based helpers a deterministic SQL text and argument order.
func getSortedKeys(exp map[string]any) []string {
	keys := make([]string, len(exp))
	i := 0
	for k := range exp {
		keys[i] = k
		i++
	}
	sort.Strings(keys)
	return keys
}
// isListType reports whether val is an array or slice to be expanded element
// by element. Values that database/sql/driver already accepts as a single
// driver.Value (notably []byte) are not lists.
func isListType(val any) bool {
	if driver.IsValue(val) {
		return false
	}
	switch reflect.ValueOf(val).Kind() { //nolint:exhaustive
	case reflect.Array, reflect.Slice:
		return true
	default:
		return false
	}
}
// sumExpr wraps a Sqlizer in the SUM aggregate function.
type sumExpr struct {
	expr Sqlizer
}

// Sum allows to use SUM function in SQL query.
// Ex: Sum(Expr("amount")) renders "SUM(amount)".
func Sum(e Sqlizer) sumExpr {
	return sumExpr{expr: e}
}

// ToSql renders "SUM(<expr>)"; errors from expr are returned unchanged.
func (e sumExpr) ToSql() (sql string, args []any, err error) {
	if sql, args, err = nestedToSql(e.expr); err != nil {
		return sql, args, err
	}
	return "SUM(" + sql + ")", args, nil
}

// countExpr wraps a Sqlizer in the COUNT aggregate function.
type countExpr struct {
	expr Sqlizer
}

// Count allows to use COUNT function in SQL query.
// Ex: Count(Expr("amount")) renders "COUNT(amount)".
func Count(e Sqlizer) countExpr {
	return countExpr{expr: e}
}

// ToSql renders "COUNT(<expr>)"; errors from expr are returned unchanged.
func (e countExpr) ToSql() (sql string, args []any, err error) {
	if sql, args, err = nestedToSql(e.expr); err != nil {
		return sql, args, err
	}
	return "COUNT(" + sql + ")", args, nil
}

// minExpr wraps a Sqlizer in the MIN aggregate function.
type minExpr struct {
	expr Sqlizer
}

// Min allows to use MIN function in SQL query.
// Ex: Min(Expr("amount")) renders "MIN(amount)".
func Min(e Sqlizer) minExpr {
	return minExpr{expr: e}
}

// ToSql renders "MIN(<expr>)"; errors from expr are returned unchanged.
func (e minExpr) ToSql() (sql string, args []any, err error) {
	if sql, args, err = nestedToSql(e.expr); err != nil {
		return sql, args, err
	}
	return "MIN(" + sql + ")", args, nil
}

// maxExpr wraps a Sqlizer in the MAX aggregate function.
type maxExpr struct {
	expr Sqlizer
}

// Max allows to use MAX function in SQL query.
// Ex: Max(Expr("amount")) renders "MAX(amount)".
func Max(e Sqlizer) maxExpr {
	return maxExpr{expr: e}
}

// ToSql renders "MAX(<expr>)"; errors from expr are returned unchanged.
func (e maxExpr) ToSql() (sql string, args []any, err error) {
	if sql, args, err = nestedToSql(e.expr); err != nil {
		return sql, args, err
	}
	return "MAX(" + sql + ")", args, nil
}

// avgExpr wraps a Sqlizer in the AVG aggregate function.
type avgExpr struct {
	expr Sqlizer
}

// Avg allows to use AVG function in SQL query.
// Ex: Avg(Expr("amount")) renders "AVG(amount)".
func Avg(e Sqlizer) avgExpr {
	return avgExpr{expr: e}
}

// ToSql renders "AVG(<expr>)"; errors from expr are returned unchanged.
func (e avgExpr) ToSql() (sql string, args []any, err error) {
	if sql, args, err = nestedToSql(e.expr); err != nil {
		return sql, args, err
	}
	return "AVG(" + sql + ")", args, nil
}

// existsExpr wraps a subquery in EXISTS (...).
type existsExpr struct {
	expr Sqlizer
}

// Exists allows to use EXISTS in SQL query
// Ex: SelectBuilder.Where(Exists(Select("id").From("accounts").Where(Eq{"id": 1})))
func Exists(e Sqlizer) existsExpr {
	return existsExpr{expr: e}
}

// ToSql renders "EXISTS (<expr>)"; errors from expr are returned unchanged.
func (e existsExpr) ToSql() (sql string, args []any, err error) {
	if sql, args, err = nestedToSql(e.expr); err != nil {
		return sql, args, err
	}
	return "EXISTS (" + sql + ")", args, nil
}

// notExistsExpr wraps a subquery in NOT EXISTS (...).
type notExistsExpr struct {
	expr Sqlizer
}

// NotExists allows to use NOT EXISTS in SQL query
// Ex: SelectBuilder.Where(NotExists(Select("id").From("accounts").Where(Eq{"id": 1})))
func NotExists(e Sqlizer) notExistsExpr {
	return notExistsExpr{expr: e}
}

// ToSql renders "NOT EXISTS (<expr>)"; errors from expr are returned unchanged.
func (e notExistsExpr) ToSql() (sql string, args []any, err error) {
	if sql, args, err = nestedToSql(e.expr); err != nil {
		return sql, args, err
	}
	return "NOT EXISTS (" + sql + ")", args, nil
}
// equalExpr compares a (sub)expression with a bound value using "=".
type equalExpr struct {
	expr  Sqlizer
	value any
}

// Equal allows to use = in SQL query: "(<e>) = ?" with v bound.
// Ex: SelectBuilder.Where(Equal(sq.Select(...), 1))
func Equal(e Sqlizer, v any) equalExpr {
	return equalExpr{expr: e, value: v}
}

func (e equalExpr) ToSql() (sql string, args []any, err error) {
	if sql, args, err = nestedToSql(e.expr); err != nil {
		return sql, args, err
	}
	return "(" + sql + ") = ?", append(args, e.value), nil
}

// notEqualExpr compares a (sub)expression with a bound value using "<>".
type notEqualExpr equalExpr

// NotEqual allows to use <> in SQL query: "(<e>) <> ?" with v bound.
// Ex: SelectBuilder.Where(NotEqual(sq.Select(...), 1))
func NotEqual(e Sqlizer, v any) notEqualExpr {
	return notEqualExpr{expr: e, value: v}
}

func (e notEqualExpr) ToSql() (sql string, args []any, err error) {
	if sql, args, err = nestedToSql(e.expr); err != nil {
		return sql, args, err
	}
	return "(" + sql + ") <> ?", append(args, e.value), nil
}

// greaterExpr compares a (sub)expression with a bound value using ">".
type greaterExpr equalExpr

// Greater allows to use > in SQL query: "(<e>) > ?" with v bound.
// Ex: SelectBuilder.Where(Greater(sq.Select(...), 1))
func Greater(e Sqlizer, v any) greaterExpr {
	return greaterExpr{expr: e, value: v}
}

func (e greaterExpr) ToSql() (sql string, args []any, err error) {
	if sql, args, err = nestedToSql(e.expr); err != nil {
		return sql, args, err
	}
	return "(" + sql + ") > ?", append(args, e.value), nil
}

// greaterOrEqualExpr compares a (sub)expression with a bound value using ">=".
type greaterOrEqualExpr equalExpr

// GreaterOrEqual allows to use >= in SQL query: "(<e>) >= ?" with v bound.
// Ex: SelectBuilder.Where(GreaterOrEqual(sq.Select(...), 1))
func GreaterOrEqual(e Sqlizer, v any) greaterOrEqualExpr {
	return greaterOrEqualExpr{expr: e, value: v}
}

func (e greaterOrEqualExpr) ToSql() (sql string, args []any, err error) {
	if sql, args, err = nestedToSql(e.expr); err != nil {
		return sql, args, err
	}
	return "(" + sql + ") >= ?", append(args, e.value), nil
}

// lessExpr compares a (sub)expression with a bound value using "<".
type lessExpr equalExpr

// Less allows to use < in SQL query: "(<e>) < ?" with v bound.
// Ex: SelectBuilder.Where(Less(sq.Select(...), 1))
func Less(e Sqlizer, v any) lessExpr {
	return lessExpr{expr: e, value: v}
}

func (e lessExpr) ToSql() (sql string, args []any, err error) {
	if sql, args, err = nestedToSql(e.expr); err != nil {
		return sql, args, err
	}
	return "(" + sql + ") < ?", append(args, e.value), nil
}

// lessOrEqualExpr compares a (sub)expression with a bound value using "<=".
type lessOrEqualExpr equalExpr

// LessOrEqual allows to use <= in SQL query: "(<e>) <= ?" with v bound.
// Ex: SelectBuilder.Where(LessOrEqual(sq.Select(...), 1))
func LessOrEqual(e Sqlizer, v any) lessOrEqualExpr {
	return lessOrEqualExpr{expr: e, value: v}
}

func (e lessOrEqualExpr) ToSql() (sql string, args []any, err error) {
	if sql, args, err = nestedToSql(e.expr); err != nil {
		return sql, args, err
	}
	return "(" + sql + ") <= ?", append(args, e.value), nil
}
// inExpr renders a membership test of column against a subquery, a list or
// a single value.
type inExpr struct {
	column string
	expr   any
}

// In allows to use IN in SQL query. e may be:
//   - a Sqlizer: "column IN (<subquery>)" (nothing if it renders to "");
//   - a list with several elements: "column=ANY(?)" with the list bound as
//     one argument;
//   - a one-element list or a scalar: "column=?";
//   - an empty list: renders "" with no args.
//
// NOTE(review): when combined with And/Or, an empty result is skipped, so
// an empty list adds no filter at all. Confirm this is intended.
//
// Ex: SelectBuilder.Where(In("id", []int{1, 2, 3}))
func In(column string, e any) inExpr {
	return inExpr{column: column, expr: e}
}

func (e inExpr) ToSql() (sql string, args []any, err error) {
	if sub, ok := e.expr.(Sqlizer); ok {
		sql, args, err = nestedToSql(sub)
		if err == nil && sql != "" {
			sql = e.column + " IN (" + sql + ")"
		}
		return sql, args, err
	}
	v := e.expr
	if !isListType(v) {
		return e.column + "=?", []any{v}, nil
	}
	list := reflect.ValueOf(v)
	switch list.Len() {
	case 0:
		return "", nil, nil
	case 1:
		return e.column + "=?", []any{list.Index(0).Interface()}, nil
	default:
		return e.column + "=ANY(?)", []any{v}, nil
	}
}

// notInExpr is the negated form of inExpr.
type notInExpr inExpr

// NotIn allows to use NOT IN in SQL query. It mirrors In, using
// "NOT IN (...)", "<>ALL(?)" and "<>?"; an empty list renders "".
// Ex: SelectBuilder.Where(NotIn("id", []int{1, 2, 3}))
func NotIn(column string, e any) notInExpr {
	return notInExpr{column: column, expr: e}
}

func (e notInExpr) ToSql() (sql string, args []any, err error) {
	if sub, ok := e.expr.(Sqlizer); ok {
		sql, args, err = nestedToSql(sub)
		if err == nil && sql != "" {
			sql = e.column + " NOT IN (" + sql + ")"
		}
		return sql, args, err
	}
	v := e.expr
	if !isListType(v) {
		return e.column + "<>?", []any{v}, nil
	}
	list := reflect.ValueOf(v)
	switch list.Len() {
	case 0:
		return "", nil, nil
	case 1:
		return e.column + "<>?", []any{list.Index(0).Interface()}, nil
	default:
		return e.column + "<>ALL(?)", []any{v}, nil
	}
}
// rangeExpr renders a BETWEEN-style range condition on a column.
type rangeExpr struct {
	column string
	start  any
	end    any
}

// Range allows to use range in SQL query
// Ex: SelectBuilder.Where(Range("id", 1, 3)) -> "id BETWEEN 1 AND 3"
// A bound that is nil or a zero value (per reflect.Value.IsZero, so 0, ""
// and the zero time also count) is omitted from the query.
// Ex: SelectBuilder.Where(Range("id", 1, nil)) -> "id >= 1"
// Ex: SelectBuilder.Where(Range("id", 0, 3)) -> "id <= 3"
// If both bounds are omitted the condition renders "" with no args.
func Range(column string, start, end any) rangeExpr {
	return rangeExpr{column: column, start: start, end: end}
}

// ToSql builds the query into a SQL string and bound args.
func (e rangeExpr) ToSql() (sql string, args []any, err error) {
	present := func(v any) bool {
		return v != nil && !reflect.ValueOf(v).IsZero()
	}
	hasStart, hasEnd := present(e.start), present(e.end)

	var s Sqlizer
	switch {
	case hasStart && hasEnd:
		s = Expr(fmt.Sprintf("%s BETWEEN ? AND ?", e.column), e.start, e.end)
	case hasStart:
		s = GtOrEq{e.column: e.start}
	case hasEnd:
		s = LtOrEq{e.column: e.end}
	default:
		return "", nil, nil
	}
	return nestedToSql(s)
}
// EqNotEmpty ignores empty and zero values in Eq map; the remaining entries
// are rendered exactly like Eq.
// Ex: EqNotEmpty{"id1": 1, "name": nil, "id2": 0, "desc": ""} -> "id1 = 1".
type EqNotEmpty map[string]any

// ToSql builds the query into a SQL string and bound args.
func (eq EqNotEmpty) ToSql() (sql string, args []any, err error) {
	filtered := Eq{}
	for key, raw := range eq {
		if cleaned := clearEmptyValue(raw); cleaned != nil {
			filtered[key] = cleaned
		}
	}
	return nestedToSql(filtered)
}
// clearEmptyValue recursively strips empty and zero values:
//   - nil and zero values (per reflect.Value.IsZero) become nil.
//   - a []byte is one opaque value, as database/sql/driver treats it. It
//     becomes nil when empty and is otherwise returned unchanged; its zero
//     bytes are not stripped.
//   - any other slice is rebuilt (same type) from its elements, each cleaned
//     recursively. Elements that clean to nil are dropped, and a slice left
//     empty becomes nil.
//   - an array is treated as a single value (nil only when all-zero), because
//     elements cannot be removed from a fixed-size array.
//
// Previously, a non-empty array panicked in reflect.MakeSlice. Nested
// elements were also appended uncleaned, and byte slices lost their zero bytes.
func clearEmptyValue(v any) any {
	if v == nil {
		return nil
	}
	if b, ok := v.([]byte); ok {
		if len(b) == 0 {
			return nil
		}
		return v
	}
	rv := reflect.ValueOf(v)
	if rv.Kind() == reflect.Slice {
		if rv.Len() == 0 {
			return nil
		}
		cleaned := reflect.MakeSlice(rv.Type(), 0, rv.Len())
		for i := 0; i < rv.Len(); i++ {
			item := clearEmptyValue(rv.Index(i).Interface())
			if item != nil {
				// item is either the original element or a cleaned slice of
				// the same type, so it is assignable to the element type.
				cleaned = reflect.Append(cleaned, reflect.ValueOf(item))
			}
		}
		if cleaned.Len() == 0 {
			return nil
		}
		return cleaned.Interface()
	}
	if rv.IsZero() {
		return nil
	}
	return v
}
// cteExpr is a single "<name> AS (<subquery>)" entry of a WITH clause.
type cteExpr struct {
	expr Sqlizer
	cte  string
}

// Cte allows to define CTE (Common Table Expressions) in SQL query: it pairs
// the subquery e with the name cte.
func Cte(e Sqlizer, cte string) cteExpr {
	return cteExpr{expr: e, cte: cte}
}

// ToSql builds the query into a SQL string and bound args: "<cte> AS (<expr>)".
func (e cteExpr) ToSql() (sql string, args []any, err error) {
	if sql, args, err = nestedToSql(e.expr); err != nil {
		return sql, args, err
	}
	return e.cte + " AS (" + sql + ")", args, nil
}
// notExpr negates a condition: "NOT (<expr>)".
type notExpr struct {
	expr Sqlizer
}

// ToSql builds the query into a SQL string and bound args.
func (e notExpr) ToSql() (sql string, args []any, err error) {
	if sql, args, err = nestedToSql(e.expr); err != nil {
		return sql, args, err
	}
	return "NOT (" + sql + ")", args, nil
}

// Not is a helper function to negate a condition. Negating a value produced
// by Not unwraps it instead of nesting: Not(Not(x)) returns x.
func Not(e Sqlizer) Sqlizer {
	inner, alreadyNegated := e.(notExpr)
	if !alreadyNegated {
		return notExpr{e}
	}
	return inner.expr
}
// coalesceExpr renders COALESCE over a list of Sqlizers, ending with a bound
// fallback value.
type coalesceExpr struct {
	exprs []Sqlizer
	null  any
}

// Coalesce is a helper function to use COALESCE in SQL query. Each non-empty
// expression is parenthesized, and nullValue is bound as the final "?"
// argument.
func Coalesce(nullValue any, exprs ...Sqlizer) Sqlizer {
	return coalesceExpr{exprs: exprs, null: nullValue}
}

// ToSql builds "COALESCE((<e1>), (<e2>), ..., ?)". Expressions that render to
// "" are skipped. If all of them do, or there are none, it renders "" with no
// args.
func (e coalesceExpr) ToSql() (sql string, args []any, err error) {
	parts := make([]string, 0, len(e.exprs))
	collected := make([]any, 0)
	for _, item := range e.exprs {
		itemSQL, itemArgs, itemErr := nestedToSql(item)
		if itemErr != nil {
			return "", nil, itemErr
		}
		if itemSQL == "" {
			continue
		}
		parts = append(parts, "("+itemSQL+")")
		collected = append(collected, itemArgs...)
	}
	if len(parts) == 0 {
		return "", nil, nil
	}
	return "COALESCE(" + strings.Join(parts, ", ") + ", ?)", append(collected, e.null), nil
}
package squirrel
import (
"bytes"
"errors"
"fmt"
"io"
"sort"
"strings"
"github.com/lann/builder"
)
// insertData holds the state accumulated by InsertBuilder (it is registered
// as InsertBuilder's backing struct in init) and renders it in toSqlRaw.
type insertData struct {
	// PlaceholderFormat rewrites "?" placeholders in the final SQL (see ToSql).
	PlaceholderFormat PlaceholderFormat
	// Prefixes are rendered, space-separated, before the statement keyword.
	Prefixes []Sqlizer
	// StatementKeyword replaces the default "INSERT" (e.g. "REPLACE"); empty means INSERT.
	StatementKeyword string
	// Options are written between the statement keyword and INTO.
	Options []string
	// Into is the target table; toSqlRaw rejects an empty value.
	Into string
	// Columns are written comma-separated inside parentheses.
	Columns []string
	// Values holds one []any per row; Sqlizer values are inlined, others bound as "?".
	Values [][]any
	// Suffixes are rendered, space-separated, after the VALUES/SELECT part.
	Suffixes []Sqlizer
	// Select, when set, is rendered instead of Values.
	Select *SelectBuilder
}
// toSqlRaw renders the INSERT statement with raw "?" placeholders.
//
// Layout: [prefixes ]<keyword> [options ]INTO <table> [(<cols>) ]<VALUES…|SELECT…>[ suffixes]
// The keyword defaults to "INSERT". When Select is set it takes precedence
// over Values. It fails if no table is set, or if neither Values nor Select
// is set.
func (d *insertData) toSqlRaw() (sqlStr string, args []any, err error) {
	switch {
	case d.Into == "":
		return "", nil, errors.New("insert statements must specify a table")
	case len(d.Values) == 0 && d.Select == nil:
		return "", nil, errors.New("insert statements must have at least one set of values or select clause")
	}
	var buf bytes.Buffer
	if len(d.Prefixes) > 0 {
		if args, err = appendToSql(d.Prefixes, &buf, " ", args); err != nil {
			return "", nil, err
		}
		_ = buf.WriteByte(' ')
	}
	keyword := "INSERT"
	if d.StatementKeyword != "" {
		keyword = d.StatementKeyword
	}
	_, _ = buf.WriteString(keyword + " ")
	if len(d.Options) > 0 {
		_, _ = buf.WriteString(strings.Join(d.Options, " ") + " ")
	}
	_, _ = buf.WriteString("INTO " + d.Into + " ")
	if len(d.Columns) > 0 {
		_, _ = buf.WriteString("(" + strings.Join(d.Columns, ",") + ") ")
	}
	if d.Select != nil {
		args, err = d.appendSelectToSQL(&buf, args)
	} else {
		args, err = d.appendValuesToSQL(&buf, args)
	}
	if err != nil {
		return "", nil, err
	}
	if len(d.Suffixes) > 0 {
		_ = buf.WriteByte(' ')
		if args, err = appendToSql(d.Suffixes, &buf, " ", args); err != nil {
			return "", nil, err
		}
	}
	return buf.String(), args, nil
}
// ToSql renders the statement via toSqlRaw and then applies PlaceholderFormat.
func (d *insertData) ToSql() (sqlStr string, args []any, err error) {
	raw, rawArgs, rawErr := d.toSqlRaw()
	if rawErr != nil {
		return "", nil, rawErr
	}
	formatted, fmtErr := d.PlaceholderFormat.ReplacePlaceholders(raw)
	return formatted, rawArgs, fmtErr
}
// appendValuesToSQL writes "VALUES (…),(…)" for every row in d.Values to w.
// It returns args extended with the bound values. Row values implementing
// Sqlizer are inlined as SQL (with their own args); all other values become
// "?" placeholders. It fails if d.Values is empty or a Sqlizer value fails to
// render.
func (d *insertData) appendValuesToSQL(w io.Writer, args []any) ([]any, error) {
	if len(d.Values) == 0 {
		return args, errors.New("values for insert statements are not set")
	}
	_, _ = io.WriteString(w, "VALUES ")
	rows := make([]string, 0, len(d.Values))
	for _, row := range d.Values {
		cells := make([]string, 0, len(row))
		for _, val := range row {
			sqlizer, isSqlizer := val.(Sqlizer)
			if !isSqlizer {
				cells = append(cells, "?")
				args = append(args, val)
				continue
			}
			cellSQL, cellArgs, err := nestedToSql(sqlizer)
			if err != nil {
				return nil, err
			}
			cells = append(cells, cellSQL)
			args = append(args, cellArgs...)
		}
		rows = append(rows, fmt.Sprintf("(%s)", strings.Join(cells, ",")))
	}
	_, _ = io.WriteString(w, strings.Join(rows, ","))
	return args, nil
}
// appendSelectToSQL writes the raw (unformatted) SQL of d.Select to w and
// returns args extended with the select's args.
func (d *insertData) appendSelectToSQL(w io.Writer, args []any) ([]any, error) {
	if d.Select == nil {
		return args, errors.New("select clause for insert statements are not set")
	}
	selectSQL, selectArgs, err := d.Select.toSqlRaw()
	if err != nil {
		return args, err
	}
	_, _ = io.WriteString(w, selectSQL)
	return append(args, selectArgs...), nil
}
// Builder

// InsertBuilder builds SQL INSERT statements.
type InsertBuilder builder.Builder
// init registers insertData as the struct backing InsertBuilder, so that
// builder.GetStruct(b).(insertData) in this file yields its fields.
func init() {
	builder.Register(InsertBuilder{}, insertData{})
}
// Format methods

// PlaceholderFormat sets the PlaceholderFormat (e.g. Question or Dollar) that
// ToSql applies to the rendered query.
func (b InsertBuilder) PlaceholderFormat(f PlaceholderFormat) InsertBuilder {
	updated := builder.Set(b, "PlaceholderFormat", f)
	return updated.(InsertBuilder)
}
// SQL methods

// ToSql builds the query into a SQL string and bound args.
func (b InsertBuilder) ToSql() (string, []any, error) {
	state := builder.GetStruct(b).(insertData)
	return state.ToSql()
}
// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b InsertBuilder) MustSql() (string, []any) {
	query, args, err := b.ToSql()
	if err == nil {
		return query, args
	}
	panic(err)
}
// Prefix adds an expression to the beginning of the query.
func (b InsertBuilder) Prefix(sql string, args ...any) InsertBuilder {
	expr := Expr(sql, args...)
	return b.PrefixExpr(expr)
}
// PrefixExpr adds an expression to the very beginning of the query.
func (b InsertBuilder) PrefixExpr(e Sqlizer) InsertBuilder {
	next := builder.Append(b, "Prefixes", e)
	return next.(InsertBuilder)
}
// Options adds keyword options before the INTO clause of the query.
func (b InsertBuilder) Options(options ...string) InsertBuilder {
	next := builder.Extend(b, "Options", options)
	return next.(InsertBuilder)
}
// Into sets the INTO clause of the query.
func (b InsertBuilder) Into(from string) InsertBuilder {
	next := builder.Set(b, "Into", from)
	return next.(InsertBuilder)
}
// Columns adds insert columns to the query.
func (b InsertBuilder) Columns(columns ...string) InsertBuilder {
	next := builder.Extend(b, "Columns", columns)
	return next.(InsertBuilder)
}
// Values adds a single row's values to the query. Values implementing Sqlizer
// are inlined as SQL; all others are bound as "?" placeholders.
func (b InsertBuilder) Values(values ...any) InsertBuilder {
	next := builder.Append(b, "Values", values)
	return next.(InsertBuilder)
}
// Suffix adds an expression to the end of the query.
func (b InsertBuilder) Suffix(sql string, args ...any) InsertBuilder {
	expr := Expr(sql, args...)
	return b.SuffixExpr(expr)
}
// SuffixExpr adds an expression to the end of the query.
func (b InsertBuilder) SuffixExpr(e Sqlizer) InsertBuilder {
	next := builder.Append(b, "Suffixes", e)
	return next.(InsertBuilder)
}
// SetMap sets the insert's columns and a single row of values from clauses.
// It replaces any columns and values set earlier. Columns are emitted in
// sorted key order so the generated SQL is deterministic.
func (b InsertBuilder) SetMap(clauses map[string]any) InsertBuilder {
	columns := make([]string, 0, len(clauses))
	for name := range clauses {
		columns = append(columns, name)
	}
	sort.Strings(columns)
	row := make([]any, len(columns))
	for i, name := range columns {
		row[i] = clauses[name]
	}
	b = builder.Set(b, "Columns", columns).(InsertBuilder)
	return builder.Set(b, "Values", [][]any{row}).(InsertBuilder)
}
// Select sets a SELECT whose rows are inserted. If both Values and Select are
// set, Select is used.
func (b InsertBuilder) Select(sb SelectBuilder) InsertBuilder {
	selectCopy := sb
	return builder.Set(b, "Select", &selectCopy).(InsertBuilder)
}
// statementKeyword overrides the leading keyword (default "INSERT").
func (b InsertBuilder) statementKeyword(keyword string) InsertBuilder {
	next := builder.Set(b, "StatementKeyword", keyword)
	return next.(InsertBuilder)
}
// toSqlRaw builds SQL with raw placeholders ("?") without applying PlaceholderFormat.
func (b InsertBuilder) toSqlRaw() (string, []any, error) {
	state := builder.GetStruct(b).(insertData)
	return state.toSqlRaw()
}
package squirrel
import (
"fmt"
"io"
)
// part is a single SQL fragment: a raw SQL string with its args, a nested
// Sqlizer, or nil (renders as nothing).
type part struct {
	pred any
	args []any
}
// newPart wraps pred (a string, a Sqlizer, or nil) and its args as a Sqlizer.
func newPart(pred any, args ...any) Sqlizer {
	return &part{pred, args}
}
// ToSql renders the part. A nil pred renders as empty SQL, a Sqlizer is
// rendered via nestedToSql, and a string is returned verbatim with p.args.
// Any other pred type is an error.
func (p part) ToSql() (sql string, args []any, err error) {
	switch pred := p.pred.(type) {
	case nil:
		// no-op: renders as empty SQL
	case Sqlizer:
		sql, args, err = nestedToSql(pred)
	case string:
		sql = pred
		args = p.args
	default:
		err = fmt.Errorf("expected string or Sqlizer, not %T", pred)
	}
	return
}
// nestedToSql renders s with raw "?" placeholders when s implements
// rawSqlizer, so nested queries are not placeholder-formatted twice. Otherwise
// it falls back to s.ToSql.
func nestedToSql(s Sqlizer) (string, []any, error) {
	if raw, ok := s.(rawSqlizer); ok {
		return raw.toSqlRaw()
	}
	return s.ToSql()
}
// appendToSql renders each part and writes the non-empty results to w,
// separated by sep, appending each written part's args to args.
//
// Parts that render to empty SQL are skipped entirely, args included. sep is
// written only between two non-empty parts, so an empty leading part cannot
// produce a dangling separator such as " AND a = ?".
func appendToSql(parts []Sqlizer, w io.Writer, sep string, args []any) ([]any, error) {
	wrote := false
	for _, p := range parts {
		partSql, partArgs, err := nestedToSql(p)
		if err != nil {
			return nil, err
		}
		if partSql == "" {
			continue
		}
		if wrote {
			if _, err = io.WriteString(w, sep); err != nil {
				return nil, err
			}
		}
		if _, err = io.WriteString(w, partSql); err != nil {
			return nil, err
		}
		wrote = true
		args = append(args, partArgs...)
	}
	return args, nil
}
package squirrel
import (
"bytes"
"fmt"
"strings"
)
// PlaceholderFormat is the interface that wraps the ReplacePlaceholders method.
//
// ReplacePlaceholders takes a SQL statement and replaces each question mark
// placeholder with a (possibly different) SQL placeholder.
type PlaceholderFormat interface {
	ReplacePlaceholders(sql string) (string, error)
}
// placeholderDebugger reports the placeholder marker a format emits; every
// format below implements it, and DebugSqlizer checks for it.
type placeholderDebugger interface {
	debugPlaceholder() string
}
var (
	// Question is a PlaceholderFormat instance that leaves placeholders as
	// question marks.
	Question = questionFormat{}
	// Dollar is a PlaceholderFormat instance that replaces placeholders with
	// dollar-prefixed positional placeholders (e.g. $1, $2, $3).
	Dollar = dollarFormat{}
	// Colon is a PlaceholderFormat instance that replaces placeholders with
	// colon-prefixed positional placeholders (e.g. :1, :2, :3).
	Colon = colonFormat{}
	// AtP is a PlaceholderFormat instance that replaces placeholders with
	// "@p"-prefixed positional placeholders (e.g. @p1, @p2, @p3).
	AtP = atpFormat{}
)
// questionFormat returns SQL untouched; note that "??" escapes are therefore
// left as-is rather than collapsed.
type questionFormat struct{}
func (questionFormat) ReplacePlaceholders(sql string) (string, error) { return sql, nil }
func (questionFormat) debugPlaceholder() string { return "?" }
// dollarFormat numbers placeholders as $1, $2, ...
type dollarFormat struct{}
func (f dollarFormat) ReplacePlaceholders(sql string) (string, error) {
	return replacePositionalPlaceholders(sql, f.debugPlaceholder())
}
func (dollarFormat) debugPlaceholder() string { return "$" }
// colonFormat numbers placeholders as :1, :2, ...
type colonFormat struct{}
func (f colonFormat) ReplacePlaceholders(sql string) (string, error) {
	return replacePositionalPlaceholders(sql, f.debugPlaceholder())
}
func (colonFormat) debugPlaceholder() string { return ":" }
// atpFormat numbers placeholders as @p1, @p2, ...
type atpFormat struct{}
func (f atpFormat) ReplacePlaceholders(sql string) (string, error) {
	return replacePositionalPlaceholders(sql, f.debugPlaceholder())
}
func (atpFormat) debugPlaceholder() string { return "@p" }
// Placeholders returns a string with count ? placeholders joined with commas.
// A count below 1 yields the empty string.
func Placeholders(count int) string {
	if count <= 0 {
		return ""
	}
	return "?" + strings.Repeat(",?", count-1)
}
// replacePositionalPlaceholders replaces each "?" in sql with prefix followed
// by its 1-based position. "??" is an escaped literal "?" and is emitted as a
// single "?" without consuming a position. The error is always nil.
func replacePositionalPlaceholders(sql, prefix string) (string, error) {
	var out bytes.Buffer
	position := 0
	rest := sql
	for {
		idx := strings.IndexByte(rest, '?')
		if idx < 0 {
			break
		}
		out.WriteString(rest[:idx])
		if strings.HasPrefix(rest[idx:], "??") {
			out.WriteByte('?')
			rest = rest[idx+2:]
			continue
		}
		position++
		fmt.Fprintf(&out, "%s%d", prefix, position)
		rest = rest[idx+1:]
	}
	out.WriteString(rest)
	return out.String(), nil
}
package squirrel
import (
"bytes"
"fmt"
"strings"
"golang.org/x/exp/slices"
"github.com/lann/builder"
)
// Direction is the sort direction of an OrderCond (see
// SelectBuilder.OrderByCond); Direction.String renders it as SQL.
type Direction int
const (
	// Asc sorts in ascending order (rendered as "ASC"); it is the zero value.
	Asc Direction = iota
	// Desc sorts in descending order (rendered as "DESC").
	Desc
)
// PaginatorType selects how a Paginator limits the results of a select.
type PaginatorType int
const (
	// PaginatorTypeUndefined is the zero value: no pagination is generated.
	PaginatorTypeUndefined PaginatorType = iota
	// PaginatorTypeByPage generates LIMIT plus an OFFSET derived from the page number.
	PaginatorTypeByPage
	// PaginatorTypeByID generates LIMIT plus a Gt condition on the IDColumn.
	PaginatorTypeByID
)
// Paginator is a helper object to paginate results. Create one with
// PaginatorByPage or PaginatorByID and apply it with SelectBuilder.Paginate.
type Paginator struct {
	limit  uint64
	page   uint64
	lastID int64
	pType  PaginatorType
}
// PaginatorByPage creates a new Paginator for pagination by page: pageSize
// rows per page, pageNum selecting the page.
func PaginatorByPage(pageSize, pageNum uint64) Paginator {
	var p Paginator
	p.pType = PaginatorTypeByPage
	p.limit = pageSize
	p.page = pageNum
	return p
}
// PaginatorByID creates a new Paginator for pagination by ID: at most limit
// rows whose ID follows lastID.
func PaginatorByID(limit uint64, lastID int64) Paginator {
	var p Paginator
	p.pType = PaginatorTypeByID
	p.limit = limit
	p.lastID = lastID
	return p
}
// PageSize returns the page size for PaginatorTypeByPage.
func (p Paginator) PageSize() uint64 { return p.limit }
// PageNumber returns the page number for PaginatorTypeByPage.
func (p Paginator) PageNumber() uint64 { return p.page }
// Limit returns the limit for PaginatorTypeByID.
func (p Paginator) Limit() uint64 { return p.limit }
// LastID returns the last ID for PaginatorTypeByID.
func (p Paginator) LastID() int64 { return p.lastID }
// Type returns the type of the paginator.
func (p Paginator) Type() PaginatorType { return p.pType }
// String returns "ASC" for Asc and "DESC" for any other value.
func (d Direction) String() string {
	switch d {
	case Asc:
		return "ASC"
	default:
		return "DESC"
	}
}
// OrderCond is one ORDER BY entry for SelectBuilder.OrderByCond.
type OrderCond struct {
	// ColumnID is a key into the columns map passed to OrderByCond.
	ColumnID int
	// Direction is the sort direction applied to that column.
	Direction Direction
}
// selectData holds the state accumulated by SelectBuilder (registered as its
// backing struct in init) and renders it in toSqlRaw.
type selectData struct {
	// PlaceholderFormat is applied to the rendered SQL by ToSql.
	PlaceholderFormat PlaceholderFormat
	// Prefixes are rendered, space-separated, before SELECT.
	Prefixes []Sqlizer
	// Options are written right after SELECT (e.g. "DISTINCT").
	Options []string
	// Columns are the result columns; toSqlRaw requires at least one.
	Columns []Sqlizer
	// From is the optional FROM source (table name or aliased subquery).
	From Sqlizer
	// Joins are rendered, space-separated, after FROM.
	Joins []Sqlizer
	// WhereParts are ANDed together in the WHERE clause.
	WhereParts []Sqlizer
	// GroupBys are comma-separated GROUP BY expressions.
	GroupBys []string
	// HavingParts are ANDed together in the HAVING clause.
	HavingParts []Sqlizer
	// OrderByParts are comma-separated ORDER BY expressions.
	OrderByParts []Sqlizer
	// Limit is the pre-formatted LIMIT value; empty means none. Cannot be
	// combined with a Paginator.
	Limit string
	// Offset is the pre-formatted OFFSET value; empty means none. Cannot be
	// combined with a Paginator.
	Offset string
	// Suffixes are rendered, space-separated, at the end of the query.
	Suffixes []Sqlizer
	// Paginator, when its type is not PaginatorTypeUndefined, generates
	// LIMIT/OFFSET (by page) or LIMIT plus an IDColumn condition (by ID).
	Paginator Paginator
	IDColumn string // ID column name. Required for pagination by ID.
}
// ToSql renders the query via toSqlRaw and then applies PlaceholderFormat.
func (d *selectData) ToSql() (sqlStr string, args []any, err error) {
	raw, rawArgs, err := d.toSqlRaw()
	if err != nil {
		return raw, rawArgs, err
	}
	formatted, err := d.PlaceholderFormat.ReplacePlaceholders(raw)
	return formatted, rawArgs, err
}
// toSqlRaw renders the SELECT statement with raw "?" placeholders (no
// PlaceholderFormat applied) and returns the SQL with its args in order.
//
// Clause order: prefixes, SELECT [options] columns, FROM, joins, WHERE,
// GROUP BY, HAVING, ORDER BY, LIMIT/OFFSET or paginator, suffixes.
//
// WHERE, HAVING and ORDER BY are emitted only when at least one of their
// parts renders to non-empty SQL. Parts can legitimately render to nothing:
// a nil predicate in part.ToSql does, and presumably so does an EqNotEmpty
// whose values are all empty. Writing the bare keyword in that case would
// produce invalid SQL such as "SELECT * FROM t WHERE ".
//
// It fails when there are no columns, when pagination by ID has no IDColumn,
// when Limit or Offset is combined with a Paginator, or when any part fails
// to render.
func (d *selectData) toSqlRaw() (sqlStr string, args []any, err error) {
	if len(d.Columns) == 0 {
		err = fmt.Errorf("select statements must have at least one result column")
		return "", nil, err
	}
	sql := &bytes.Buffer{}
	// writeClause renders parts into a scratch buffer and emits keyword plus
	// the rendered text only if something was rendered. appendToSql never
	// returns args for parts it skipped as empty, so nothing is lost.
	writeClause := func(keyword string, parts []Sqlizer, sep string) error {
		clause := &bytes.Buffer{}
		clauseArgs, clauseErr := appendToSql(parts, clause, sep, nil)
		if clauseErr != nil {
			return clauseErr
		}
		if clause.Len() == 0 {
			return nil
		}
		_, _ = sql.WriteString(keyword)
		_, _ = sql.Write(clause.Bytes())
		args = append(args, clauseArgs...)
		return nil
	}
	if len(d.Prefixes) > 0 {
		args, err = appendToSql(d.Prefixes, sql, " ", args)
		if err != nil {
			return "", nil, err
		}
		_, _ = sql.WriteString(" ")
	}
	_, _ = sql.WriteString("SELECT ")
	if len(d.Options) > 0 {
		_, _ = sql.WriteString(strings.Join(d.Options, " "))
		_, _ = sql.WriteString(" ")
	}
	// Columns are non-empty here (checked above).
	args, err = appendToSql(d.Columns, sql, ", ", args)
	if err != nil {
		return "", nil, err
	}
	if d.From != nil {
		_, _ = sql.WriteString(" FROM ")
		args, err = appendToSql([]Sqlizer{d.From}, sql, "", args)
		if err != nil {
			return "", nil, err
		}
	}
	if len(d.Joins) > 0 {
		_, _ = sql.WriteString(" ")
		args, err = appendToSql(d.Joins, sql, " ", args)
		if err != nil {
			return "", nil, err
		}
	}
	// Copy so the paginator condition is never appended into d.WhereParts'
	// backing array.
	whereParts := make([]Sqlizer, len(d.WhereParts), len(d.WhereParts)+1)
	copy(whereParts, d.WhereParts)
	if d.Paginator.pType == PaginatorTypeByID {
		if d.IDColumn == "" {
			return "", nil, fmt.Errorf("IDColumn is required for pagination by ID")
		}
		whereParts = append(whereParts, Gt{d.IDColumn: d.Paginator.lastID})
	}
	if err = writeClause(" WHERE ", whereParts, " AND "); err != nil {
		return "", nil, err
	}
	if len(d.GroupBys) > 0 {
		_, _ = sql.WriteString(" GROUP BY ")
		_, _ = sql.WriteString(strings.Join(d.GroupBys, ", "))
	}
	if err = writeClause(" HAVING ", d.HavingParts, " AND "); err != nil {
		return "", nil, err
	}
	if err = writeClause(" ORDER BY ", d.OrderByParts, ", "); err != nil {
		return "", nil, err
	}
	if len(d.Limit) > 0 {
		if d.Paginator.pType != PaginatorTypeUndefined {
			return "", nil, fmt.Errorf("limit and paginator cannot be used together")
		}
		_, _ = sql.WriteString(" LIMIT ")
		_, _ = sql.WriteString(d.Limit)
	}
	if len(d.Offset) > 0 {
		if d.Paginator.pType != PaginatorTypeUndefined {
			return "", nil, fmt.Errorf("offset and paginator cannot be used together")
		}
		_, _ = sql.WriteString(" OFFSET ")
		_, _ = sql.WriteString(d.Offset)
	}
	switch d.Paginator.pType {
	case PaginatorTypeByPage:
		_, _ = fmt.Fprintf(sql, " LIMIT %d", d.Paginator.limit)
		// Pages 0 and 1 both start at the first row.
		if d.Paginator.page > 1 {
			_, _ = fmt.Fprintf(sql, " OFFSET %d", d.Paginator.limit*(d.Paginator.page-1))
		}
	case PaginatorTypeByID:
		_, _ = fmt.Fprintf(sql, " LIMIT %d", d.Paginator.limit)
	}
	if len(d.Suffixes) > 0 {
		_, _ = sql.WriteString(" ")
		args, err = appendToSql(d.Suffixes, sql, " ", args)
		if err != nil {
			return "", nil, err
		}
	}
	sqlStr = sql.String()
	return sqlStr, args, nil
}
// Builder

// SelectBuilder builds SQL SELECT statements.
type SelectBuilder builder.Builder
// init registers selectData as the struct backing SelectBuilder, so that
// builder.GetStruct(b).(selectData) in this file yields its fields.
func init() {
	builder.Register(SelectBuilder{}, selectData{})
}
// Format methods

// PlaceholderFormat sets the PlaceholderFormat (e.g. Question or Dollar) that
// ToSql applies to the rendered query.
func (b SelectBuilder) PlaceholderFormat(f PlaceholderFormat) SelectBuilder {
	updated := builder.Set(b, "PlaceholderFormat", f)
	return updated.(SelectBuilder)
}
// SQL methods

// ToSql builds the query into a SQL string and bound args.
func (b SelectBuilder) ToSql() (string, []any, error) {
	state := builder.GetStruct(b).(selectData)
	return state.ToSql()
}
// toSqlRaw builds the query with raw "?" placeholders. nestedToSql prefers it,
// so a SelectBuilder nested in another query keeps raw placeholders.
func (b SelectBuilder) toSqlRaw() (string, []any, error) {
	state := builder.GetStruct(b).(selectData)
	return state.toSqlRaw()
}
// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b SelectBuilder) MustSql() (string, []any) {
	query, args, err := b.ToSql()
	if err == nil {
		return query, args
	}
	panic(err)
}
// Prefix adds an expression to the beginning of the query.
func (b SelectBuilder) Prefix(sql string, args ...any) SelectBuilder {
	expr := Expr(sql, args...)
	return b.PrefixExpr(expr)
}
// PrefixExpr adds an expression to the very beginning of the query.
func (b SelectBuilder) PrefixExpr(e Sqlizer) SelectBuilder {
	next := builder.Append(b, "Prefixes", e)
	return next.(SelectBuilder)
}
// Distinct adds a DISTINCT clause to the query.
func (b SelectBuilder) Distinct() SelectBuilder {
	const distinct = "DISTINCT"
	return b.Options(distinct)
}
// Options adds select options, written right after SELECT.
func (b SelectBuilder) Options(options ...string) SelectBuilder {
	next := builder.Extend(b, "Options", options)
	return next.(SelectBuilder)
}
// Columns adds result columns to the query.
func (b SelectBuilder) Columns(columns ...string) SelectBuilder {
	parts := make([]any, len(columns))
	for i, column := range columns {
		parts[i] = newPart(column)
	}
	next := builder.Extend(b, "Columns", parts)
	return next.(SelectBuilder)
}
// RemoveColumns removes all columns from the query.
// At least one column must be added again with Column or Columns, otherwise
// building the query returns an error.
func (b SelectBuilder) RemoveColumns() SelectBuilder {
	next := builder.Delete(b, "Columns")
	return next.(SelectBuilder)
}
// Column adds a result column to the query.
// Unlike Columns, Column accepts args which will be bound to placeholders in
// the columns string, for example:
//
//	Column("IF(col IN ("+squirrel.Placeholders(3)+"), 1, 0) as col", 1, 2, 3)
func (b SelectBuilder) Column(column any, args ...any) SelectBuilder {
	next := builder.Append(b, "Columns", newPart(column, args...))
	return next.(SelectBuilder)
}
// From sets the FROM clause of the query.
func (b SelectBuilder) From(from string) SelectBuilder {
	next := builder.Set(b, "From", newPart(from))
	return next.(SelectBuilder)
}
// FromSelect sets a subquery, aliased as alias, as the FROM clause of the query.
func (b SelectBuilder) FromSelect(from SelectBuilder, alias string) SelectBuilder {
	next := builder.Set(b, "From", Alias(from, alias))
	return next.(SelectBuilder)
}
// JoinClause adds a join clause to the query.
func (b SelectBuilder) JoinClause(pred any, args ...any) SelectBuilder {
	next := builder.Append(b, "Joins", newPart(pred, args...))
	return next.(SelectBuilder)
}
// Join adds a JOIN clause to the query.
func (b SelectBuilder) Join(join string, rest ...any) SelectBuilder {
	clause := "JOIN " + join
	return b.JoinClause(clause, rest...)
}
// LeftJoin adds a LEFT JOIN clause to the query.
func (b SelectBuilder) LeftJoin(join string, rest ...any) SelectBuilder {
	clause := "LEFT JOIN " + join
	return b.JoinClause(clause, rest...)
}
// RightJoin adds a RIGHT JOIN clause to the query.
func (b SelectBuilder) RightJoin(join string, rest ...any) SelectBuilder {
	clause := "RIGHT JOIN " + join
	return b.JoinClause(clause, rest...)
}
// InnerJoin adds an INNER JOIN clause to the query.
func (b SelectBuilder) InnerJoin(join string, rest ...any) SelectBuilder {
	clause := "INNER JOIN " + join
	return b.JoinClause(clause, rest...)
}
// CrossJoin adds a CROSS JOIN clause to the query.
func (b SelectBuilder) CrossJoin(join string, rest ...any) SelectBuilder {
	clause := "CROSS JOIN " + join
	return b.JoinClause(clause, rest...)
}
// Where adds an expression to the WHERE clause of the query.
//
// Expressions are ANDed together in the generated SQL.
//
// Where accepts several types for its pred argument:
//
// nil OR "" - ignored.
//
// string - SQL expression.
// If the expression has SQL placeholders then a set of arguments must be passed
// as well, one for each placeholder.
//
// map[string]any OR Eq - map of SQL expressions to values. Each key is
// transformed into an expression like "<key> = ?", with the corresponding value
// bound to the placeholder. If the value is nil, the expression will be "<key>
// IS NULL". If the value is an array or slice, the expression will be "<key> IN
// (?,?,...)", with one placeholder for each item in the value. These expressions
// are ANDed together.
//
// Where will panic if pred isn't any of the above types.
func (b SelectBuilder) Where(pred any, args ...any) SelectBuilder {
	if pred == nil || pred == "" {
		return b
	}
	return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(SelectBuilder)
}
// GroupBy adds GROUP BY expressions to the query.
func (b SelectBuilder) GroupBy(groupBys ...string) SelectBuilder {
	return builder.Extend(b, "GroupBys", groupBys).(SelectBuilder)
}
// Having adds an expression to the HAVING clause of the query.
//
// See Where. As with Where, a nil or "" pred is ignored rather than appended,
// because it would render as an empty condition after the HAVING keyword.
func (b SelectBuilder) Having(pred any, rest ...any) SelectBuilder {
	if pred == nil || pred == "" {
		return b
	}
	return builder.Append(b, "HavingParts", newWherePart(pred, rest...)).(SelectBuilder)
}
// OrderByClause adds an ORDER BY clause (SQL string or Sqlizer, with args) to the query.
func (b SelectBuilder) OrderByClause(pred any, args ...any) SelectBuilder {
	clause := newPart(pred, args...)
	return builder.Append(b, "OrderByParts", clause).(SelectBuilder)
}
// OrderBy adds ORDER BY expressions to the query, one clause per expression.
func (b SelectBuilder) OrderBy(orderBys ...string) SelectBuilder {
	result := b
	for i := range orderBys {
		result = result.OrderByClause(orderBys[i])
	}
	return result
}
// OrderNullsType is used to specify the order of NULLs in ORDER BY clause.
type OrderNullsType int
const (
	OrderNullsUndefined OrderNullsType = iota // no NULLS modifier
	OrderNullsFirst                           // ORDER BY ... NULLS FIRST
	OrderNullsLast                            // ORDER BY ... NULLS LAST
)
// String returns "FIRST" or "LAST" for the matching constant and "" for any
// other value.
func (o OrderNullsType) String() string {
	switch o {
	case OrderNullsFirst:
		return "FIRST"
	case OrderNullsLast:
		return "LAST"
	default:
		return ""
	}
}
// OrderByCondOption attaches a NULLS ordering to the OrderCond with the same
// ColumnID in SelectBuilder.OrderByCond.
type OrderByCondOption struct {
	ColumnID  int
	NullsType OrderNullsType
}
// OrderByCond adds ORDER BY expressions with direction to the query.
// columns maps OrderCond.ColumnID to the column name, which avoids
// hardcoding column names at call sites.
//
// Only the first OrderCond for a given ColumnID is used. The first opts entry
// with a matching ColumnID adds "NULLS FIRST"/"NULLS LAST". It panics if a
// ColumnID is missing from columns.
func (b SelectBuilder) OrderByCond(columns map[int]string, conds []OrderCond, opts ...OrderByCondOption) SelectBuilder {
	for i, cond := range conds {
		seenBefore := slices.IndexFunc(conds[:i], func(prev OrderCond) bool {
			return prev.ColumnID == cond.ColumnID
		}) >= 0
		if seenBefore {
			continue
		}
		column, found := columns[cond.ColumnID]
		if !found {
			panic(fmt.Sprintf("column id %d not found in columns map %v", cond.ColumnID, columns))
		}
		nulls := OrderNullsUndefined
		for _, opt := range opts {
			if opt.ColumnID == cond.ColumnID {
				nulls = opt.NullsType
				break
			}
		}
		clause := fmt.Sprintf("%s %s", column, cond.Direction.String())
		if nulls != OrderNullsUndefined {
			clause = fmt.Sprintf("%s NULLS %s", clause, nulls.String())
		}
		b = b.OrderByClause(clause)
	}
	return b
}
// Search adds a substring search condition to the query.
//
// Each column is cast to text and matched with LIKE against "%<value>%".
// The per-column conditions are ORed together and added via Where. value is
// formatted with %v, so strings and numbers both work. With no columns the
// builder is returned unchanged.
//
// LIKE metacharacters in value ('%', '_', and the escape character '\') are
// backslash-escaped so they match literally. Without this, a search for
// "50%" or "a_b" would behave as a wildcard pattern. The "::text" cast is
// PostgreSQL syntax, and PostgreSQL's LIKE uses backslash as its default
// escape character.
func (b SelectBuilder) Search(value any, columns ...string) SelectBuilder {
	if len(columns) == 0 {
		return b
	}
	escaper := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`)
	pattern := "%" + escaper.Replace(fmt.Sprint(value)) + "%"
	search := Or{}
	for _, column := range columns {
		search = append(search, Like{column + "::text": pattern})
	}
	return b.Where(search)
}
// PaginateByID sets LIMIT to limit and adds Gt{columnID: startID} to the WHERE clause.
// WARNING: The columnID must be included in the ORDER BY clause to avoid unexpected results!
func (b SelectBuilder) PaginateByID(limit uint64, startID int64, columnID string) SelectBuilder {
	limited := b.Limit(limit)
	return limited.Where(Gt{columnID: startID})
}
// PaginateByPage sets LIMIT to limit and, for page > 1, OFFSET to limit*(page-1).
// WARNING: query must be ordered to avoid unexpected results!
func (b SelectBuilder) PaginateByPage(limit uint64, page uint64) SelectBuilder {
	if page <= 1 {
		return b.Limit(limit)
	}
	return b.Limit(limit).Offset(limit * (page - 1))
}
// Paginate stores p so pagination is generated when the query is rendered.
// It cannot be combined with Limit or Offset.
func (b SelectBuilder) Paginate(p Paginator) SelectBuilder {
	next := builder.Set(b, "Paginator", p)
	return next.(SelectBuilder)
}
// SetIDColumn sets the column name to be used for pagination by ID.
// Required when Paginate is combined with PaginatorByID.
func (b SelectBuilder) SetIDColumn(column string) SelectBuilder {
	next := builder.Set(b, "IDColumn", column)
	return next.(SelectBuilder)
}
// Limit sets a LIMIT clause on the query.
func (b SelectBuilder) Limit(limit uint64) SelectBuilder {
	text := fmt.Sprintf("%d", limit)
	return builder.Set(b, "Limit", text).(SelectBuilder)
}
// RemoveLimit removes the LIMIT clause so all matching rows are returned.
func (b SelectBuilder) RemoveLimit() SelectBuilder {
	next := builder.Delete(b, "Limit")
	return next.(SelectBuilder)
}
// Offset sets an OFFSET clause on the query.
func (b SelectBuilder) Offset(offset uint64) SelectBuilder {
	text := fmt.Sprintf("%d", offset)
	return builder.Set(b, "Offset", text).(SelectBuilder)
}
// RemoveOffset removes the OFFSET clause.
func (b SelectBuilder) RemoveOffset() SelectBuilder {
	next := builder.Delete(b, "Offset")
	return next.(SelectBuilder)
}
// Suffix adds an expression to the end of the query.
func (b SelectBuilder) Suffix(sql string, args ...any) SelectBuilder {
	expr := Expr(sql, args...)
	return b.SuffixExpr(expr)
}
// SuffixExpr adds an expression to the end of the query.
func (b SelectBuilder) SuffixExpr(e Sqlizer) SelectBuilder {
	next := builder.Append(b, "Suffixes", e)
	return next.(SelectBuilder)
}
// alias is returned by SelectBuilder.Alias. It qualifies column names with a
// table name and optionally aliases them with a prefix (see
// prepareAliasColumns) before handing them to the wrapped builder.
type alias struct {
	builder SelectBuilder
	table   string
	prefix  []string
}
// Columns sets the columns for the table alias.
func (a alias) Columns(columns ...string) SelectBuilder {
	if len(columns) > 0 {
		return a.builder.Columns(prepareAliasColumns(a.table, a.prefix, columns...)...)
	}
	return a.builder
}
// GroupBy sets the group by for the table alias.
func (a alias) GroupBy(groupBys ...string) SelectBuilder {
	if len(groupBys) > 0 {
		return a.builder.GroupBy(prepareAliasColumns(a.table, a.prefix, groupBys...)...)
	}
	return a.builder
}
// OrderBy sets the order by for the table alias.
func (a alias) OrderBy(orderBys ...string) SelectBuilder {
	if len(orderBys) > 0 {
		return a.builder.OrderBy(prepareAliasColumns(a.table, a.prefix, orderBys...)...)
	}
	return a.builder
}
// prepareAliasColumns qualifies each column as "<table>.<column>" (or leaves
// it bare when table is empty). When prefix is non-empty, it appends
// " AS <prefix[0]>_<column>"; any further prefixes are ignored.
func prepareAliasColumns(table string, prefix []string, columns ...string) []string {
	qualify := func(column string) string {
		if table == "" {
			return column
		}
		return table + "." + column
	}
	prepared := make([]string, len(columns))
	for i, column := range columns {
		name := qualify(column)
		if len(prefix) > 0 {
			name = fmt.Sprintf("%s AS %s_%s", name, prefix[0], column)
		}
		prepared[i] = name
	}
	return prepared
}
// Alias returns a helper whose Columns/GroupBy/OrderBy methods qualify column
// names with table before adding them to b.
// If a prefix is given, each column is also aliased as "<prefix>_<column>";
// otherwise the column name is used. All prefixes except the first are ignored.
func (b SelectBuilder) Alias(table string, prefix ...string) alias {
	var a alias
	a.builder = b
	a.table = table
	a.prefix = prefix
	return a
}
// With adds a CTE (Common Table Expression) to the query by prepending cte,
// wrapped as "WITH <cteName> AS ( … )", as a prefix.
//
// NOTE(review): every call adds its own "WITH" keyword, so calling With more
// than once presumably yields "WITH a AS (…) WITH b AS (…)", which most
// databases reject. Confirm before chaining.
func (b SelectBuilder) With(cteName string, cte SelectBuilder) SelectBuilder {
	wrapped := cte.Prefix(fmt.Sprintf("WITH %s AS (", cteName))
	wrapped = wrapped.Suffix(")")
	return b.PrefixExpr(wrapped)
}
// Package squirrel provides a fluent SQL generator.
//
// See https://github.com/Masterminds/squirrel for examples.
package squirrel
import (
"bytes"
"fmt"
"strings"
)
// Sqlizer is the interface that wraps the ToSql method.
//
// ToSql returns a SQL representation of the Sqlizer, along with a slice of args
// as passed to e.g. database/sql.Exec. It can also return an error.
type Sqlizer interface {
	ToSql() (string, []any, error)
}
// rawSqlizer is expected to do what Sqlizer does, but without finalizing placeholders.
// This is useful for nested queries.
type rawSqlizer interface {
	toSqlRaw() (string, []any, error)
}
// DebugSqlizer calls ToSql on s and shows the approximate SQL to be executed,
// with each placeholder replaced by its argument in single quotes ('%v').
// "??" is shown as a literal "?".
//
// If ToSql returns an error, or the number of placeholders does not match the
// number of args, the result looks like:
// "[ToSql error: %s]" or "[DebugSqlizer error: %s]"
//
// IMPORTANT: As its name suggests, this function should only be used for
// debugging. While the string result *might* be valid SQL, this function does
// not try very hard to ensure it. Additionally, executing the output of this
// function with any untrusted user input is certainly insecure.
func DebugSqlizer(s Sqlizer) string {
	sql, args, err := s.ToSql()
	if err != nil {
		return fmt.Sprintf("[ToSql error: %s]", err)
	}
	placeholder := "?"
	if pd, ok := s.(placeholderDebugger); ok {
		placeholder = pd.debugPlaceholder()
	}
	// TODO: dedupe this with placeholder.go
	var out bytes.Buffer
	used := 0
	for {
		idx := strings.Index(sql, placeholder)
		if idx == -1 {
			break
		}
		if strings.HasPrefix(sql[idx:], "??") { // escape ?? => ?
			out.WriteString(sql[:idx])
			out.WriteString("?")
			sql = sql[idx+2:]
			continue
		}
		if used >= len(args) {
			return fmt.Sprintf(
				"[DebugSqlizer error: too many placeholders in %#v for %d args]",
				sql, len(args))
		}
		out.WriteString(sql[:idx])
		fmt.Fprintf(&out, "'%v'", args[used])
		// move the sql "cursor" past the placeholder just interpolated
		sql = sql[idx+1:]
		used++
	}
	if used < len(args) {
		return fmt.Sprintf(
			"[DebugSqlizer error: not enough placeholders in %#v for %d args]",
			sql, len(args))
	}
	// append any remaining sql that needs no interpolation
	out.WriteString(sql)
	return out.String()
}
package squirrel
import "github.com/lann/builder"
// StatementBuilderType is the type of StatementBuilder. Its methods convert
// the receiver directly into a concrete builder (e.g. SelectBuilder(b)), so
// values set on it, such as PlaceholderFormat, are inherited by that builder.
type StatementBuilderType builder.Builder
// Select returns a SelectBuilder for this StatementBuilderType.
func (b StatementBuilderType) Select(columns ...string) SelectBuilder {
	sb := SelectBuilder(b)
	return sb.Columns(columns...)
}
// Insert returns an InsertBuilder for this StatementBuilderType.
func (b StatementBuilderType) Insert(into string) InsertBuilder {
	ib := InsertBuilder(b)
	return ib.Into(into)
}
// Replace returns an InsertBuilder for this StatementBuilderType with the
// statement keyword set to "REPLACE".
func (b StatementBuilderType) Replace(into string) InsertBuilder {
	ib := InsertBuilder(b).statementKeyword("REPLACE")
	return ib.Into(into)
}
// Update returns an UpdateBuilder for this StatementBuilderType.
func (b StatementBuilderType) Update(table string) UpdateBuilder {
	ub := UpdateBuilder(b)
	return ub.Table(table)
}
// Delete returns a DeleteBuilder for this StatementBuilderType.
func (b StatementBuilderType) Delete(from string) DeleteBuilder {
	db := DeleteBuilder(b)
	return db.From(from)
}
// With returns a CommonTableExpressionsBuilder for this StatementBuilderType.
func (b StatementBuilderType) With(cte string) CommonTableExpressionsBuilder {
	cb := CommonTableExpressionsBuilder(b)
	return cb.Cte(cte)
}
// PlaceholderFormat sets the PlaceholderFormat field for any child builders.
func (b StatementBuilderType) PlaceholderFormat(f PlaceholderFormat) StatementBuilderType {
	updated := builder.Set(b, "PlaceholderFormat", f)
	return updated.(StatementBuilderType)
}
// Where adds WHERE expressions to the query.
//
// See SelectBuilder.Where for more information. As with SelectBuilder.Where,
// a nil or "" pred is ignored rather than stored. Otherwise child builders
// would inherit an empty condition.
func (b StatementBuilderType) Where(pred any, args ...any) StatementBuilderType {
	if pred == nil || pred == "" {
		return b
	}
	return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(StatementBuilderType)
}
// StatementBuilder is a parent builder for other builders, e.g. SelectBuilder.
// It starts from builder.EmptyBuilder with the Question placeholder format, so
// the package-level Select/Insert/... helpers emit "?" placeholders by default.
var StatementBuilder = StatementBuilderType(builder.EmptyBuilder).PlaceholderFormat(Question)
// Select returns a new SelectBuilder, optionally setting some result columns.
//
// See SelectBuilder.Columns.
func Select(columns ...string) SelectBuilder {
return StatementBuilder.Select(columns...)
}
// Insert returns a new InsertBuilder with the given table name.
//
// See InsertBuilder.Into.
func Insert(into string) InsertBuilder {
return StatementBuilder.Insert(into)
}
// Replace returns a new InsertBuilder, derived from StatementBuilder, whose
// statement keyword is "REPLACE" and whose target is the given table.
//
// See InsertBuilder.Into.
func Replace(into string) InsertBuilder {
	parent := StatementBuilder
	return parent.Replace(into)
}
// Update returns a new UpdateBuilder, derived from StatementBuilder, that
// updates the given table.
//
// See UpdateBuilder.Table.
func Update(table string) UpdateBuilder {
	parent := StatementBuilder
	return parent.Update(table)
}
// Delete returns a new DeleteBuilder, derived from StatementBuilder, that
// deletes from the given table.
//
// See DeleteBuilder.From.
func Delete(from string) DeleteBuilder {
	return StatementBuilder.Delete(from)
}
// With returns a new CommonTableExpressionsBuilder, derived from
// StatementBuilder, whose first common table expression is named cte.
//
// See CommonTableExpressionsBuilder.Cte.
func With(cte string) CommonTableExpressionsBuilder {
	parent := StatementBuilder
	return parent.With(cte)
}
// WithRecursive returns a new CommonTableExpressionsBuilder with the
// RECURSIVE option enabled and whose first common table expression is
// named cte.
//
// See CommonTableExpressionsBuilder.Cte and CommonTableExpressionsBuilder.Recursive.
func WithRecursive(cte string) CommonTableExpressionsBuilder {
	cb := StatementBuilder.With(cte)
	return cb.Recursive(true)
}
// Case returns a new CaseBuilder.
//
// The optional what arguments describe the case value:
//   - no arguments: no case value is set;
//   - one argument: it is used directly as the case value;
//   - more arguments: the first is an expression and the rest are its
//     arguments, combined via newPart.
func Case(what ...any) CaseBuilder {
	cb := CaseBuilder(builder.EmptyBuilder)
	if len(what) == 0 {
		return cb
	}
	if len(what) == 1 {
		return cb.what(what[0])
	}
	return cb.what(newPart(what[0], what[1:]...))
}
package squirrel
import (
"bytes"
"fmt"
"sort"
"strings"
"github.com/lann/builder"
)
// updateData holds the accumulated state of an UpdateBuilder (it is the
// struct registered for UpdateBuilder in init) and renders it via
// toSqlRaw/ToSql.
type updateData struct {
	// PlaceholderFormat rewrites the raw "?" placeholders in ToSql.
	PlaceholderFormat PlaceholderFormat
	// Prefixes are emitted before "UPDATE", separated by spaces.
	Prefixes []Sqlizer
	// Table is the target table, written into the SQL verbatim.
	Table string
	// SetClauses are the "column = value" assignments, in call order.
	SetClauses []setClause
	// From is the optional FROM source (table name or aliased subquery).
	From Sqlizer
	// WhereParts are joined with " AND " after WHERE.
	WhereParts []Sqlizer
	// OrderBys are joined with ", " after ORDER BY.
	OrderBys []string
	// Limit is a decimal string; empty means no LIMIT clause.
	Limit string
	// Offset is a decimal string; empty means no OFFSET clause.
	Offset string
	// Suffixes are emitted at the very end, separated by spaces.
	Suffixes []Sqlizer
}
// setClause is a single "column = value" assignment of an UPDATE statement.
type setClause struct {
	// column is written verbatim on the left-hand side of the assignment.
	column string
	// value is bound as a "?" argument unless it is a Sqlizer, in which case
	// its SQL is inlined (wrapped in parentheses for a SelectBuilder).
	value any
}
// toSqlRaw renders the UPDATE statement with raw "?" placeholders; the
// PlaceholderFormat is not applied here.
//
// Clauses are emitted in this order: prefixes, "UPDATE <table> SET ...",
// FROM, WHERE, ORDER BY, LIMIT, OFFSET, suffixes. It returns an error when
// no table was given, when there are no SET clauses, or when any nested
// Sqlizer fails to render.
func (d *updateData) toSqlRaw() (sqlStr string, args []any, err error) {
	switch {
	case len(d.Table) == 0:
		return "", nil, fmt.Errorf("update statements must specify a table")
	case len(d.SetClauses) == 0:
		return "", nil, fmt.Errorf("update statements must have at least one Set clause")
	}

	buf := &bytes.Buffer{}

	if len(d.Prefixes) > 0 {
		if args, err = appendToSql(d.Prefixes, buf, " ", args); err != nil {
			return "", nil, err
		}
		_, _ = buf.WriteString(" ")
	}

	_, _ = buf.WriteString("UPDATE ")
	_, _ = buf.WriteString(d.Table)
	_, _ = buf.WriteString(" SET ")

	assignments := make([]string, len(d.SetClauses))
	for i, sc := range d.SetClauses {
		vs, isSqlizer := sc.value.(Sqlizer)
		if !isSqlizer {
			// Plain values are bound through a single placeholder.
			assignments[i] = sc.column + " = ?"
			args = append(args, sc.value)
			continue
		}

		// Sqlizer values are rendered inline and their args spliced in order.
		var (
			exprSql  string
			exprArgs []any
		)
		if exprSql, exprArgs, err = nestedToSql(vs); err != nil {
			return "", nil, err
		}
		if _, isSelect := vs.(SelectBuilder); isSelect {
			// A subquery must be parenthesized to be used as a SET value.
			exprSql = "(" + exprSql + ")"
		}
		assignments[i] = sc.column + " = " + exprSql
		args = append(args, exprArgs...)
	}
	_, _ = buf.WriteString(strings.Join(assignments, ", "))

	if d.From != nil {
		_, _ = buf.WriteString(" FROM ")
		if args, err = appendToSql([]Sqlizer{d.From}, buf, "", args); err != nil {
			return "", nil, err
		}
	}

	if len(d.WhereParts) > 0 {
		_, _ = buf.WriteString(" WHERE ")
		if args, err = appendToSql(d.WhereParts, buf, " AND ", args); err != nil {
			return "", nil, err
		}
	}

	if len(d.OrderBys) > 0 {
		_, _ = buf.WriteString(" ORDER BY " + strings.Join(d.OrderBys, ", "))
	}
	if len(d.Limit) > 0 {
		_, _ = buf.WriteString(" LIMIT " + d.Limit)
	}
	if len(d.Offset) > 0 {
		_, _ = buf.WriteString(" OFFSET " + d.Offset)
	}

	if len(d.Suffixes) > 0 {
		_, _ = buf.WriteString(" ")
		if args, err = appendToSql(d.Suffixes, buf, " ", args); err != nil {
			return "", nil, err
		}
	}

	return buf.String(), args, nil
}
// ToSql builds the UPDATE statement and its bound arguments, then rewrites
// the raw "?" placeholders using d.PlaceholderFormat.
//
// On any error it returns an empty string and nil args, matching toSqlRaw,
// so callers never receive partially valid output alongside an error.
func (d *updateData) ToSql() (sqlStr string, args []any, err error) {
	raw, rawArgs, err := d.toSqlRaw()
	if err != nil {
		return "", nil, err
	}
	sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(raw)
	if err != nil {
		// Previously rawArgs leaked out together with the error.
		return "", nil, err
	}
	return sqlStr, rawArgs, nil
}
// Builder

// UpdateBuilder builds SQL UPDATE statements.
//
// Its methods take a value receiver and return the updated builder, so calls
// are chained, e.g. Update("t").Set("c", 1).Where("id = ?", 7).
type UpdateBuilder builder.Builder
func init() {
builder.Register(UpdateBuilder{}, updateData{})
}
// Format methods
// PlaceholderFormat sets the PlaceholderFormat (e.g. Question or Dollar)
// that ToSql uses to rewrite the query's "?" placeholders.
func (b UpdateBuilder) PlaceholderFormat(f PlaceholderFormat) UpdateBuilder {
	updated := builder.Set(b, "PlaceholderFormat", f)
	return updated.(UpdateBuilder)
}
// SQL methods
// ToSql builds the query into a SQL string and its bound args, with
// placeholders rewritten according to the builder's PlaceholderFormat.
func (b UpdateBuilder) ToSql() (string, []any, error) {
	d := builder.GetStruct(b).(updateData)
	return d.ToSql()
}
// MustSql is like ToSql but panics with the build error instead of
// returning it.
func (b UpdateBuilder) MustSql() (string, []any) {
	query, args, err := b.ToSql()
	if err == nil {
		return query, args
	}
	panic(err)
}
// Prefix adds an expression built from sql and args to the beginning of
// the query. It is shorthand for PrefixExpr(Expr(sql, args...)).
func (b UpdateBuilder) Prefix(sql string, args ...any) UpdateBuilder {
	e := Expr(sql, args...)
	return b.PrefixExpr(e)
}
// PrefixExpr appends e to the prefixes emitted before "UPDATE"; multiple
// prefixes are separated by a space.
func (b UpdateBuilder) PrefixExpr(e Sqlizer) UpdateBuilder {
	updated := builder.Append(b, "Prefixes", e)
	return updated.(UpdateBuilder)
}
// Table sets the name of the table to be updated. The name is written into
// the SQL verbatim, without quoting.
func (b UpdateBuilder) Table(table string) UpdateBuilder {
	updated := builder.Set(b, "Table", table)
	return updated.(UpdateBuilder)
}
// Set appends a "column = value" assignment to the SET clause. A Sqlizer
// value is rendered inline (a SelectBuilder inside parentheses); any other
// value is bound as a placeholder argument.
func (b UpdateBuilder) Set(column string, value any) UpdateBuilder {
	clause := setClause{column: column, value: value}
	return builder.Append(b, "SetClauses", clause).(UpdateBuilder)
}
// SetMap calls Set once for each key/value pair in clauses. Keys are visited
// in sorted order, so the generated SQL does not depend on Go's randomized
// map iteration order.
func (b UpdateBuilder) SetMap(clauses map[string]any) UpdateBuilder {
	columns := make([]string, 0, len(clauses))
	for column := range clauses {
		columns = append(columns, column)
	}
	sort.Strings(columns)

	for _, column := range columns {
		b = b.Set(column, clauses[column])
	}
	return b
}
// From sets the FROM clause of the query to the given table expression.
//
// UPDATE ... FROM is non-standard SQL (it is PostgreSQL syntax); check that
// the target database supports it.
func (b UpdateBuilder) From(from string) UpdateBuilder {
	src := newPart(from)
	return builder.Set(b, "From", src).(UpdateBuilder)
}
// FromSelect sets the FROM clause of the query to the subquery from,
// aliased as alias. It replaces any source set earlier by From.
func (b UpdateBuilder) FromSelect(from SelectBuilder, alias string) UpdateBuilder {
	src := Alias(from, alias)
	return builder.Set(b, "From", src).(UpdateBuilder)
}
// Where adds a WHERE expression to the query; multiple calls are combined
// with AND.
//
// See SelectBuilder.Where for more information.
func (b UpdateBuilder) Where(pred any, args ...any) UpdateBuilder {
	wp := newWherePart(pred, args...)
	return builder.Append(b, "WhereParts", wp).(UpdateBuilder)
}
// OrderBy appends ORDER BY expressions to the query. They are joined with
// ", " and written verbatim.
func (b UpdateBuilder) OrderBy(orderBys ...string) UpdateBuilder {
	updated := builder.Extend(b, "OrderBys", orderBys)
	return updated.(UpdateBuilder)
}
// Limit sets the LIMIT clause of the query to limit, replacing any earlier
// value.
func (b UpdateBuilder) Limit(limit uint64) UpdateBuilder {
	text := fmt.Sprintf("%d", limit)
	return builder.Set(b, "Limit", text).(UpdateBuilder)
}
// Offset sets the OFFSET clause of the query to offset, replacing any
// earlier value.
func (b UpdateBuilder) Offset(offset uint64) UpdateBuilder {
	text := fmt.Sprintf("%d", offset)
	return builder.Set(b, "Offset", text).(UpdateBuilder)
}
// Suffix adds an expression built from sql and args to the end of the
// query. It is shorthand for SuffixExpr(Expr(sql, args...)).
func (b UpdateBuilder) Suffix(sql string, args ...any) UpdateBuilder {
	e := Expr(sql, args...)
	return b.SuffixExpr(e)
}
// toSqlRaw builds the query with raw "?" placeholders and does not apply
// the builder's PlaceholderFormat. wherePart.ToSql prefers this (via
// rawSqlizer) over ToSql when the builder is nested inside another query.
func (b UpdateBuilder) toSqlRaw() (string, []any, error) {
	d := builder.GetStruct(b).(updateData)
	return d.toSqlRaw()
}
// SuffixExpr appends e to the suffixes emitted at the very end of the query;
// multiple suffixes are separated by a space.
func (b UpdateBuilder) SuffixExpr(e Sqlizer) UpdateBuilder {
	updated := builder.Append(b, "Suffixes", e)
	return updated.(UpdateBuilder)
}
package squirrel
import (
"fmt"
)
// wherePart is a WHERE predicate (pred) together with its bind arguments
// (args); it shares its layout with part but has its own ToSql.
type wherePart part
// newWherePart wraps a WHERE predicate and its bind arguments in a Sqlizer.
// See wherePart.ToSql for the accepted predicate types.
func newWherePart(pred any, args ...any) Sqlizer {
	wp := wherePart{pred: pred, args: args}
	return &wp
}
// ToSql renders the predicate according to its dynamic type:
//   - nil: empty SQL, no args, no error;
//   - rawSqlizer: its toSqlRaw output (placeholders left as "?");
//   - Sqlizer: its ToSql output;
//   - map[string]any: rendered as Eq;
//   - string: used verbatim, with the wherePart's args.
//
// Any other type yields an error.
func (p wherePart) ToSql() (sql string, args []any, err error) {
	switch pred := p.pred.(type) {
	case nil:
		return "", nil, nil
	case rawSqlizer:
		// Must stay ahead of the Sqlizer case: builders such as UpdateBuilder
		// implement both, and the raw form leaves placeholders unformatted.
		return pred.toSqlRaw()
	case Sqlizer:
		return pred.ToSql()
	case map[string]any:
		return Eq(pred).ToSql()
	case string:
		return pred, p.args, nil
	default:
		return "", nil, fmt.Errorf("expected string-keyed map or string, not %T", pred)
	}
}