package squirrel import ( "bytes" "errors" "fmt" "reflect" "time" "github.com/lann/builder" ) func init() { builder.Register(CaseBuilder{}, caseData{}) } // sqlizerBuffer is a helper that allows to write many Sqlizers one by one // without constant checks for errors that may come from Sqlizer type sqlizerBuffer struct { bytes.Buffer args []any err error } // WriteSql converts Sqlizer to SQL strings and writes it to buffer func (b *sqlizerBuffer) WriteSql(item Sqlizer) { if b.err != nil { return } var str string var args []any str, args, b.err = nestedToSql(item) if b.err != nil { return } _, _ = b.WriteString(str) _ = b.WriteByte(' ') b.args = append(b.args, args...) } func (b *sqlizerBuffer) ToSql() (string, []any, error) { return b.String(), b.args, b.err } // whenPart is a helper structure to describe SQLs "WHEN ... THEN ..." expression type whenPart struct { when Sqlizer then Sqlizer thenValue any nullThen bool } func newWhenPart(when any, then any) whenPart { wp := whenPart{ when: newPart(when), } switch t := then.(type) { case Sqlizer: wp.then = newPart(then) default: if t == nil { wp.nullThen = true } else { sqlName, err := sqlTypeNameHelper(reflect.TypeOf(then)) if err != nil { wp.thenValue = t } else { wp.then = newPart(Expr(fmt.Sprintf("CAST(? 
AS %s)", sqlName), t)) } } } return wp } func sqlTypeNameHelper(t reflect.Type) (string, error) { switch t.Kind() { //nolint:exhaustive case reflect.Bool: return "boolean", nil case reflect.Int64, reflect.Uint64, reflect.Int, reflect.Uint: return "bigint", nil case reflect.Int32, reflect.Uint32: return "integer", nil case reflect.Int16, reflect.Uint16, reflect.Int8, reflect.Uint8: return "smallint", nil case reflect.Float32, reflect.Float64: return "double precision", nil case reflect.String: return "text", nil case reflect.Struct: if t == reflect.TypeOf(time.Time{}) { return "timestamp with time zone", nil } case reflect.Slice, reflect.Array: sqlType, err := sqlTypeNameHelper(t.Elem()) if err != nil { return "", err } return sqlType + "[]", nil } return "", fmt.Errorf("unsupported type %s", t.Name()) } // caseData holds all the data required to build a CASE SQL construct type caseData struct { What Sqlizer WhenParts []whenPart Else Sqlizer ElseValue any ElseNull bool } // ToSql implements Sqlizer func (d *caseData) ToSql() (sqlStr string, args []any, err error) { if len(d.WhenParts) == 0 { return "", nil, errors.New("case expression must contain at lease one WHEN clause") } sql := sqlizerBuffer{} sql.WriteString("CASE ") if d.What != nil { sql.WriteSql(d.What) } for _, p := range d.WhenParts { sql.WriteString("WHEN ") sql.WriteSql(p.when) if p.then == nil && p.thenValue == nil && !p.nullThen { return "", nil, errors.New("When clause must have Then part") } sql.WriteString("THEN ") if p.then != nil { sql.WriteSql(p.then) } else { sql.WriteString(Placeholders(1) + " ") sql.args = append(sql.args, p.thenValue) } } if d.Else != nil || d.ElseValue != nil || d.ElseNull { sql.WriteString("ELSE ") } if d.Else != nil { sql.WriteSql(d.Else) } else if d.ElseValue != nil || d.ElseNull { sql.WriteString(Placeholders(1) + " ") sql.args = append(sql.args, d.ElseValue) } sql.WriteString("END") return sql.ToSql() } // CaseBuilder builds SQL CASE construct which could be used as 
parts of queries. type CaseBuilder builder.Builder // ToSql builds the query into a SQL string and bound args. func (b CaseBuilder) ToSql() (string, []any, error) { data := builder.GetStruct(b).(caseData) return data.ToSql() } // MustSql builds the query into a SQL string and bound args. // It panics if there are any errors. func (b CaseBuilder) MustSql() (string, []any) { sql, args, err := b.ToSql() if err != nil { panic(err) } return sql, args } // what sets optional value for CASE construct "CASE [value] ..." func (b CaseBuilder) what(e any) CaseBuilder { return builder.Set(b, "What", newPart(e)).(CaseBuilder) } // When adds "WHEN ... THEN ..." part to CASE construct func (b CaseBuilder) When(when any, then any) CaseBuilder { // TODO: performance hint: replace slice of WhenPart with just slice of parts // where even indices of the slice belong to "when"s and odd indices belong to "then"s return builder.Append(b, "WhenParts", newWhenPart(when, then)).(CaseBuilder) } // Else What sets optional "ELSE ..." part for CASE construct func (b CaseBuilder) Else(e any) CaseBuilder { switch e.(type) { case Sqlizer: return builder.Set(b, "Else", newPart(e)).(CaseBuilder) default: if e == nil { return builder.Set(b, "ElseNull", true).(CaseBuilder) } return builder.Set(b, "ElseValue", e).(CaseBuilder) } }
package squirrel import ( "bytes" "fmt" "github.com/lann/builder" ) // Common Table Expressions helper // e.g. // WITH cte AS ( // ... // ), cte_2 AS ( // ... // ) // SELECT ... FROM cte ... cte_2; type commonTableExpressionsData struct { PlaceholderFormat PlaceholderFormat Recursive bool CurrentCteName string Ctes []Sqlizer Statement Sqlizer } func (d *commonTableExpressionsData) toSql() (sqlStr string, args []any, err error) { if len(d.Ctes) == 0 { err = fmt.Errorf("common table expressions statements must have at least one label and subquery") return "", nil, err } if d.Statement == nil { err = fmt.Errorf("common table expressions must one of the following final statement: (select, insert, replace, update, delete)") return "", nil, err } sql := &bytes.Buffer{} _, _ = sql.WriteString("WITH ") if d.Recursive { _, _ = sql.WriteString("RECURSIVE ") } args, err = appendToSql(d.Ctes, sql, ", ", args) if err != nil { return "", nil, err } _, _ = sql.WriteString(" ") args, err = appendToSql([]Sqlizer{d.Statement}, sql, "", args) if err != nil { return "", nil, err } sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String()) return sqlStr, args, err } func (d *commonTableExpressionsData) ToSql() (sql string, args []any, err error) { return d.toSql() } // Builder // CommonTableExpressionsBuilder builds CTE (Common Table Expressions) SQL statements. type CommonTableExpressionsBuilder builder.Builder func init() { builder.Register(CommonTableExpressionsBuilder{}, commonTableExpressionsData{}) } // Format methods // PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the // query. func (b CommonTableExpressionsBuilder) PlaceholderFormat(f PlaceholderFormat) CommonTableExpressionsBuilder { return builder.Set(b, "PlaceholderFormat", f).(CommonTableExpressionsBuilder) } // SQL methods // ToSql builds the query into a SQL string and bound args. 
func (b CommonTableExpressionsBuilder) ToSql() (string, []any, error) { data := builder.GetStruct(b).(commonTableExpressionsData) return data.ToSql() } // MustSql builds the query into a SQL string and bound args. // It panics if there are any errors. func (b CommonTableExpressionsBuilder) MustSql() (string, []any) { sql, args, err := b.ToSql() if err != nil { panic(err) } return sql, args } func (b CommonTableExpressionsBuilder) Recursive(recursive bool) CommonTableExpressionsBuilder { return builder.Set(b, "Recursive", recursive).(CommonTableExpressionsBuilder) } // Cte starts a new cte func (b CommonTableExpressionsBuilder) Cte(cte string) CommonTableExpressionsBuilder { return builder.Set(b, "CurrentCteName", cte).(CommonTableExpressionsBuilder) } // As sets the expression for the Cte func (b CommonTableExpressionsBuilder) As(as SelectBuilder) CommonTableExpressionsBuilder { data := builder.GetStruct(b).(commonTableExpressionsData) return builder.Append(b, "Ctes", cteExpr{as, data.CurrentCteName}).(CommonTableExpressionsBuilder) } // Select finalizes the CommonTableExpressionsBuilder with a SELECT func (b CommonTableExpressionsBuilder) Select(statement SelectBuilder) CommonTableExpressionsBuilder { return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder) } // Insert finalizes the CommonTableExpressionsBuilder with an INSERT func (b CommonTableExpressionsBuilder) Insert(statement InsertBuilder) CommonTableExpressionsBuilder { return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder) } // Replace finalizes the CommonTableExpressionsBuilder with a REPLACE func (b CommonTableExpressionsBuilder) Replace(statement InsertBuilder) CommonTableExpressionsBuilder { return b.Insert(statement) } // Update finalizes the CommonTableExpressionsBuilder with an UPDATE func (b CommonTableExpressionsBuilder) Update(statement UpdateBuilder) CommonTableExpressionsBuilder { return builder.Set(b, "Statement", 
statement).(CommonTableExpressionsBuilder) } // Delete finalizes the CommonTableExpressionsBuilder with a DELETE func (b CommonTableExpressionsBuilder) Delete(statement DeleteBuilder) CommonTableExpressionsBuilder { return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder) }
package squirrel import ( "bytes" "fmt" "strings" "github.com/lann/builder" ) type deleteData struct { PlaceholderFormat PlaceholderFormat Prefixes []Sqlizer From string WhereParts []Sqlizer OrderBys []string Limit string Offset string Suffixes []Sqlizer } func (d *deleteData) ToSql() (sqlStr string, args []any, err error) { if len(d.From) == 0 { err = fmt.Errorf("delete statements must specify a From table") return "", nil, err } sql := &bytes.Buffer{} if len(d.Prefixes) > 0 { args, err = appendToSql(d.Prefixes, sql, " ", args) if err != nil { return "", nil, err } sql.WriteString(" ") } sql.WriteString("DELETE FROM ") sql.WriteString(d.From) if len(d.WhereParts) > 0 { sql.WriteString(" WHERE ") args, err = appendToSql(d.WhereParts, sql, " AND ", args) if err != nil { return "", nil, err } } if len(d.OrderBys) > 0 { _, _ = sql.WriteString(" ORDER BY ") _, _ = sql.WriteString(strings.Join(d.OrderBys, ", ")) } if len(d.Limit) > 0 { _, _ = sql.WriteString(" LIMIT ") _, _ = sql.WriteString(d.Limit) } if len(d.Offset) > 0 { _, _ = sql.WriteString(" OFFSET ") _, _ = sql.WriteString(d.Offset) } if len(d.Suffixes) > 0 { _, _ = sql.WriteString(" ") args, err = appendToSql(d.Suffixes, sql, " ", args) if err != nil { return "", nil, err } } sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String()) return sqlStr, args, err } // Builder // DeleteBuilder builds SQL DELETE statements. type DeleteBuilder builder.Builder func init() { builder.Register(DeleteBuilder{}, deleteData{}) } // Format methods // PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the // query. func (b DeleteBuilder) PlaceholderFormat(f PlaceholderFormat) DeleteBuilder { return builder.Set(b, "PlaceholderFormat", f).(DeleteBuilder) } // SQL methods // ToSql builds the query into a SQL string and bound args. 
func (b DeleteBuilder) ToSql() (string, []any, error) { data := builder.GetStruct(b).(deleteData) return data.ToSql() } // MustSql builds the query into a SQL string and bound args. // It panics if there are any errors. func (b DeleteBuilder) MustSql() (string, []any) { sql, args, err := b.ToSql() if err != nil { panic(err) } return sql, args } // Prefix adds an expression to the beginning of the query func (b DeleteBuilder) Prefix(sql string, args ...any) DeleteBuilder { return b.PrefixExpr(Expr(sql, args...)) } // PrefixExpr adds an expression to the very beginning of the query func (b DeleteBuilder) PrefixExpr(e Sqlizer) DeleteBuilder { return builder.Append(b, "Prefixes", e).(DeleteBuilder) } // From sets the table to be deleted from. func (b DeleteBuilder) From(from string) DeleteBuilder { return builder.Set(b, "From", from).(DeleteBuilder) } // Where adds WHERE expressions to the query. // // See SelectBuilder.Where for more information. func (b DeleteBuilder) Where(pred any, args ...any) DeleteBuilder { return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(DeleteBuilder) } // OrderBy adds ORDER BY expressions to the query. func (b DeleteBuilder) OrderBy(orderBys ...string) DeleteBuilder { return builder.Extend(b, "OrderBys", orderBys).(DeleteBuilder) } // Limit sets a LIMIT clause on the query. func (b DeleteBuilder) Limit(limit uint64) DeleteBuilder { return builder.Set(b, "Limit", fmt.Sprintf("%d", limit)).(DeleteBuilder) } // Offset sets a OFFSET clause on the query. func (b DeleteBuilder) Offset(offset uint64) DeleteBuilder { return builder.Set(b, "Offset", fmt.Sprintf("%d", offset)).(DeleteBuilder) } // Suffix adds an expression to the end of the query func (b DeleteBuilder) Suffix(sql string, args ...any) DeleteBuilder { return b.SuffixExpr(Expr(sql, args...)) } // SuffixExpr adds an expression to the end of the query func (b DeleteBuilder) SuffixExpr(e Sqlizer) DeleteBuilder { return builder.Append(b, "Suffixes", e).(DeleteBuilder) }
package squirrel import ( "bytes" "database/sql/driver" "fmt" "reflect" "sort" "strings" ) const ( // Portable true/false literals. sqlTrue = "(1=1)" sqlFalse = "(1=0)" ) type expr struct { sql string args []any } // Expr builds an expression from a SQL fragment and arguments. // // Ex: // // Expr("FROM_UNIXTIME(?)", t) func Expr(sql string, args ...any) Sqlizer { return expr{sql: sql, args: args} } func (e expr) ToSql() (sql string, args []any, err error) { simple := true for _, arg := range e.args { if _, ok := arg.(Sqlizer); ok { simple = false } if isListType(arg) { simple = false } } if simple { return e.sql, e.args, nil } buf := &bytes.Buffer{} ap := e.args sp := e.sql var isql string var iargs []any for err == nil && len(ap) > 0 && len(sp) > 0 { i := strings.Index(sp, "?") if i < 0 { // no more placeholders break } if len(sp) > i+1 && sp[i+1:i+2] == "?" { // escaped "??"; append it and step past buf.WriteString(sp[:i+2]) sp = sp[i+2:] continue } if as, ok := ap[0].(Sqlizer); ok { // sqlizer argument; expand it and append the result isql, iargs, err = as.ToSql() buf.WriteString(sp[:i]) buf.WriteString(isql) args = append(args, iargs...) } else { // normal argument; append it and the placeholder buf.WriteString(sp[:i+1]) args = append(args, ap[0]) } // step past the argument and placeholder ap = ap[1:] sp = sp[i+1:] } // append the remaining sql and arguments buf.WriteString(sp) return buf.String(), append(args, ap...), err } type concatExpr []any func (ce concatExpr) ToSql() (sql string, args []any, err error) { for _, part := range ce { switch p := part.(type) { case string: sql += p case Sqlizer: pSql, pArgs, err := p.ToSql() if err != nil { return "", nil, err } sql += pSql args = append(args, pArgs...) default: return "", nil, fmt.Errorf("%#v is not a string or Sqlizer", part) } } return } // ConcatExpr builds an expression by concatenating strings and other expressions. 
// // Ex: // // name_expr := Expr("CONCAT(?, ' ', ?)", firstName, lastName) // ConcatExpr("COALESCE(full_name,", name_expr, ")") func ConcatExpr(parts ...any) concatExpr { return concatExpr(parts) } // aliasExpr helps to alias part of SQL query generated with underlying "expr" type aliasExpr struct { expr Sqlizer alias string } // Alias allows to define alias for column in SelectBuilder. Useful when column is // defined as complex expression like IF or CASE // Ex: // // .Column(Alias(caseStmt, "case_column")) func Alias(e Sqlizer, a string) aliasExpr { return aliasExpr{e, a} } func (e aliasExpr) ToSql() (sql string, args []any, err error) { sql, args, err = e.expr.ToSql() if err == nil { sql = fmt.Sprintf("(%s) AS %s", sql, e.alias) } return } // Eq is syntactic sugar for use with Where/Having/Set methods. type Eq map[string]any func (eq Eq) toSQL(useNotOpr bool) (sql string, args []any, err error) { if len(eq) == 0 { // Empty Sql{} evaluates to true. sql = sqlTrue return sql, args, nil } var ( exprs = make([]string, 0, len(eq)) equalOpr = "=" inOpr = "IN" nullOpr = "IS" inEmptyExpr = sqlFalse ) if useNotOpr { equalOpr = "<>" inOpr = "NOT IN" nullOpr = "IS NOT" inEmptyExpr = sqlTrue } sortedKeys := getSortedKeys(eq) for _, key := range sortedKeys { var expr1 string val := eq[key] switch v := val.(type) { case driver.Valuer: if val, err = v.Value(); err != nil { return "", nil, err } } r := reflect.ValueOf(val) if r.Kind() == reflect.Ptr { if r.IsNil() { val = nil } else { val = r.Elem().Interface() } } if val == nil { expr1 = fmt.Sprintf("%s %s NULL", key, nullOpr) } else { if isListType(val) { valVal := reflect.ValueOf(val) if valVal.Len() == 0 { expr1 = inEmptyExpr if args == nil { args = []any{} } } else { for i := 0; i < valVal.Len(); i++ { args = append(args, valVal.Index(i).Interface()) } expr1 = fmt.Sprintf("%s %s (%s)", key, inOpr, Placeholders(valVal.Len())) } } else if sb, ok := val.(SelectBuilder); ok { var ( subSql string subArgs []any ) subSql, 
subArgs, err = sb.toSqlRaw() if err != nil { return "", nil, err } expr1 = fmt.Sprintf("%s %s (%s)", key, inOpr, subSql) args = append(args, subArgs...) } else { expr1 = fmt.Sprintf("%s %s ?", key, equalOpr) args = append(args, val) } } exprs = append(exprs, expr1) } sql = strings.Join(exprs, " AND ") return sql, args, nil } func (eq Eq) ToSql() (sql string, args []any, err error) { return eq.toSQL(false) } // NotEq is syntactic sugar for use with Where/Having/Set methods. // Ex: // // .Where(NotEq{"id": 1}) == "id <> 1" type NotEq Eq func (neq NotEq) ToSql() (sql string, args []any, err error) { return Eq(neq).toSQL(true) } // Like is syntactic sugar for use with LIKE conditions. // Ex: // // .Where(Like{"name": "%irrel"}) type Like map[string]any func (lk Like) toSql(opr string) (sql string, args []any, err error) { exprs := make([]string, 0, len(lk)) for key, val := range lk { var expr1 string switch v := val.(type) { case driver.Valuer: if val, err = v.Value(); err != nil { return } } if val == nil { err = fmt.Errorf("cannot use null with like operators") return } else { if isListType(val) { err = fmt.Errorf("cannot use array or slice with like operators") return } else { expr1 = fmt.Sprintf("%s %s ?", key, opr) args = append(args, val) } } exprs = append(exprs, expr1) } sql = strings.Join(exprs, " AND ") return } func (lk Like) ToSql() (sql string, args []any, err error) { return lk.toSql("LIKE") } // NotLike is syntactic sugar for use with LIKE conditions. // Ex: // // .Where(NotLike{"name": "%irrel"}) type NotLike Like func (nlk NotLike) ToSql() (sql string, args []any, err error) { return Like(nlk).toSql("NOT LIKE") } // ILike is syntactic sugar for use with ILIKE conditions. // Ex: // // .Where(ILike{"name": "sq%"}) type ILike Like func (ilk ILike) ToSql() (sql string, args []any, err error) { return Like(ilk).toSql("ILIKE") } // NotILike is syntactic sugar for use with ILIKE conditions. 
// Ex: // // .Where(NotILike{"name": "sq%"}) type NotILike Like func (nilk NotILike) ToSql() (sql string, args []any, err error) { return Like(nilk).toSql("NOT ILIKE") } // Lt is syntactic sugar for use with Where/Having/Set methods. // Ex: // // .Where(Lt{"id": 1}) type Lt map[string]any func (lt Lt) toSql(opposite, orEq bool) (sql string, args []any, err error) { var ( exprs = make([]string, 0, len(lt)) opr = "<" ) if opposite { opr = ">" } if orEq { opr = fmt.Sprintf("%s%s", opr, "=") } sortedKeys := getSortedKeys(lt) for _, key := range sortedKeys { var expr1 string val := lt[key] switch v := val.(type) { case driver.Valuer: if val, err = v.Value(); err != nil { return "", nil, err } } if val == nil { err = fmt.Errorf("cannot use null with less than or greater than operators") return "", nil, err } if isListType(val) { err = fmt.Errorf("cannot use array or slice with less than or greater than operators") return "", nil, err } expr1 = fmt.Sprintf("%s %s ?", key, opr) args = append(args, val) exprs = append(exprs, expr1) } sql = strings.Join(exprs, " AND ") return sql, args, nil } func (lt Lt) ToSql() (sql string, args []any, err error) { return lt.toSql(false, false) } // LtOrEq is syntactic sugar for use with Where/Having/Set methods. // Ex: // // .Where(LtOrEq{"id": 1}) == "id <= 1" type LtOrEq Lt func (ltOrEq LtOrEq) ToSql() (sql string, args []any, err error) { return Lt(ltOrEq).toSql(false, true) } // Gt is syntactic sugar for use with Where/Having/Set methods. // Ex: // // .Where(Gt{"id": 1}) == "id > 1" type Gt Lt func (gt Gt) ToSql() (sql string, args []any, err error) { return Lt(gt).toSql(true, false) } // GtOrEq is syntactic sugar for use with Where/Having/Set methods. 
// Ex: // // .Where(GtOrEq{"id": 1}) == "id >= 1" type GtOrEq Lt func (gtOrEq GtOrEq) ToSql() (sql string, args []any, err error) { return Lt(gtOrEq).toSql(true, true) } type conj []Sqlizer func (c conj) join(sep, defaultExpr string) (sql string, args []any, err error) { if len(c) == 0 { return defaultExpr, []any{}, nil } var sqlParts []string for _, sqlizer := range c { partSQL, partArgs, err := nestedToSql(sqlizer) if err != nil { return "", nil, err } if partSQL != "" { sqlParts = append(sqlParts, partSQL) args = append(args, partArgs...) } } if len(sqlParts) > 0 { sql = fmt.Sprintf("(%s)", strings.Join(sqlParts, sep)) } return } // And conjunction Sqlizers type And conj func (a And) ToSql() (string, []any, error) { return conj(a).join(" AND ", sqlTrue) } // Or conjunction Sqlizers type Or conj func (o Or) ToSql() (string, []any, error) { return conj(o).join(" OR ", sqlFalse) } func getSortedKeys(exp map[string]any) []string { sortedKeys := make([]string, 0, len(exp)) for k := range exp { sortedKeys = append(sortedKeys, k) } sort.Strings(sortedKeys) return sortedKeys } func isListType(val any) bool { if driver.IsValue(val) { return false } valVal := reflect.ValueOf(val) return valVal.Kind() == reflect.Array || valVal.Kind() == reflect.Slice } // sumExpr helps to use aggregate function SUM in SQL query type sumExpr struct { expr Sqlizer } // Sum allows to use SUM function in SQL query // Ex: SelectBuilder.Select("id", Sum("amount")) func Sum(e Sqlizer) sumExpr { return sumExpr{e} } func (e sumExpr) ToSql() (sql string, args []any, err error) { sql, args, err = e.expr.ToSql() if err == nil { sql = fmt.Sprintf("SUM(%s)", sql) } return } // countExpr helps to use aggregate function COUNT in SQL query type countExpr struct { expr Sqlizer } // Count allows to use COUNT function in SQL query // Ex: SelectBuilder.Select("id", Count("amount")) func Count(e Sqlizer) countExpr { return countExpr{e} } func (e countExpr) ToSql() (sql string, args []any, err error) { sql, 
args, err = e.expr.ToSql() if err == nil { sql = fmt.Sprintf("COUNT(%s)", sql) } return } // minExpr helps to use aggregate function MIN in SQL query type minExpr struct { expr Sqlizer } // Min allows to use MIN function in SQL query // Ex: SelectBuilder.Select("id", Min("amount")) func Min(e Sqlizer) minExpr { return minExpr{e} } func (e minExpr) ToSql() (sql string, args []any, err error) { sql, args, err = e.expr.ToSql() if err == nil { sql = fmt.Sprintf("MIN(%s)", sql) } return } // maxExpr helps to use aggregate function MAX in SQL query type maxExpr struct { expr Sqlizer } // Max allows to use MAX function in SQL query // Ex: SelectBuilder.Select("id", Max("amount")) func Max(e Sqlizer) maxExpr { return maxExpr{e} } func (e maxExpr) ToSql() (sql string, args []any, err error) { sql, args, err = e.expr.ToSql() if err == nil { sql = fmt.Sprintf("MAX(%s)", sql) } return } // avgExpr helps to use aggregate function AVG in SQL query type avgExpr struct { expr Sqlizer } // Avg allows to use AVG function in SQL query // Ex: SelectBuilder.Select("id", Avg("amount")) func Avg(e Sqlizer) avgExpr { return avgExpr{e} } func (e avgExpr) ToSql() (sql string, args []any, err error) { sql, args, err = e.expr.ToSql() if err == nil { sql = fmt.Sprintf("AVG(%s)", sql) } return } // existsExpr helps to use EXISTS in SQL query type existsExpr struct { expr Sqlizer } // Exists allows to use EXISTS in SQL query // Ex: SelectBuilder.Where(Exists(Select("id").From("accounts").Where(Eq{"id": 1}))) func Exists(e Sqlizer) existsExpr { return existsExpr{e} } func (e existsExpr) ToSql() (sql string, args []any, err error) { sql, args, err = e.expr.ToSql() if err == nil { sql = fmt.Sprintf("EXISTS (%s)", sql) } return } // notExistsExpr helps to use NOT EXISTS in SQL query type notExistsExpr struct { expr Sqlizer } // NotExists allows to use NOT EXISTS in SQL query // Ex: SelectBuilder.Where(NotExists(Select("id").From("accounts").Where(Eq{"id": 1}))) func NotExists(e Sqlizer) 
notExistsExpr { return notExistsExpr{e} } func (e notExistsExpr) ToSql() (sql string, args []any, err error) { sql, args, err = e.expr.ToSql() if err == nil { sql = fmt.Sprintf("NOT EXISTS (%s)", sql) } return } // equalExpr helps to use = in SQL query type equalExpr struct { expr Sqlizer value any } // Equal allows to use = in SQL query // Ex: SelectBuilder.Where(Equal(sq.Select(...), 1)) func Equal(e Sqlizer, v any) equalExpr { return equalExpr{e, v} } func (e equalExpr) ToSql() (sql string, args []any, err error) { sql, args, err = e.expr.ToSql() if err == nil { sql = fmt.Sprintf("(%s) = ?", sql) args = append(args, e.value) } return } // notEqualExpr helps to use <> in SQL query type notEqualExpr equalExpr // NotEqual allows to use <> in SQL query // Ex: SelectBuilder.Where(NotEqual(sq.Select(...), 1)) func NotEqual(e Sqlizer, v any) notEqualExpr { return notEqualExpr{e, v} } func (e notEqualExpr) ToSql() (sql string, args []any, err error) { sql, args, err = e.expr.ToSql() if err == nil { sql = fmt.Sprintf("(%s) <> ?", sql) args = append(args, e.value) } return } // greaterExpr helps to use > in SQL query type greaterExpr equalExpr // Greater allows to use > in SQL query // Ex: SelectBuilder.Where(Greater(sq.Select(...), 1)) func Greater(e Sqlizer, v any) greaterExpr { return greaterExpr{e, v} } func (e greaterExpr) ToSql() (sql string, args []any, err error) { sql, args, err = e.expr.ToSql() if err == nil { sql = fmt.Sprintf("(%s) > ?", sql) args = append(args, e.value) } return } // greaterOrEqualExpr helps to use >= in SQL query type greaterOrEqualExpr equalExpr // GreaterOrEqual allows to use >= in SQL query // Ex: SelectBuilder.Where(GreaterOrEqual(sq.Select(...), 1)) func GreaterOrEqual(e Sqlizer, v any) greaterOrEqualExpr { return greaterOrEqualExpr{e, v} } func (e greaterOrEqualExpr) ToSql() (sql string, args []any, err error) { sql, args, err = e.expr.ToSql() if err == nil { sql = fmt.Sprintf("(%s) >= ?", sql) args = append(args, e.value) } return } 
// lessExpr helps to use < in SQL query type lessExpr equalExpr // Less allows to use < in SQL query // Ex: SelectBuilder.Where(Less(sq.Select(...), 1)) func Less(e Sqlizer, v any) lessExpr { return lessExpr{e, v} } func (e lessExpr) ToSql() (sql string, args []any, err error) { sql, args, err = e.expr.ToSql() if err == nil { sql = fmt.Sprintf("(%s) < ?", sql) args = append(args, e.value) } return } // lessOrEqualExpr helps to use <= in SQL query type lessOrEqualExpr equalExpr // LessOrEqual allows to use <= in SQL query // Ex: SelectBuilder.Where(LessOrEqual(sq.Select(...), 1)) func LessOrEqual(e Sqlizer, v any) lessOrEqualExpr { return lessOrEqualExpr{e, v} } func (e lessOrEqualExpr) ToSql() (sql string, args []any, err error) { sql, args, err = e.expr.ToSql() if err == nil { sql = fmt.Sprintf("(%s) <= ?", sql) args = append(args, e.value) } return } // inExpr helps to use IN in SQL query type inExpr struct { column string expr any } // In allows to use IN in SQL query // Ex: SelectBuilder.Where(In("id", 1, 2, 3)) func In(column string, e any) inExpr { return inExpr{column, e} } func (e inExpr) ToSql() (sql string, args []any, err error) { switch v := e.expr.(type) { case Sqlizer: sql, args, err = v.ToSql() if err == nil && sql != "" { sql = fmt.Sprintf("%s IN (%s)", e.column, sql) } default: if isListType(v) { if reflect.ValueOf(v).Len() == 0 { return "", nil, nil } if reflect.ValueOf(v).Len() == 1 { args = []any{reflect.ValueOf(v).Index(0).Interface()} sql = fmt.Sprintf("%s=?", e.column) } else { args = []any{v} sql = fmt.Sprintf("%s=ANY(?)", e.column) } } else { args = []any{v} sql = fmt.Sprintf("%s=?", e.column) } } return sql, args, err } // notInExpr helps to use NOT IN in SQL query type notInExpr inExpr // NotIn allows to use NOT IN in SQL query // Ex: SelectBuilder.Where(NotIn("id", 1, 2, 3)) func NotIn(column string, e any) notInExpr { return notInExpr{column, e} } func (e notInExpr) ToSql() (sql string, args []any, err error) { switch v := e.expr.(type) 
{ case Sqlizer: sql, args, err = v.ToSql() if err == nil && sql != "" { sql = fmt.Sprintf("%s NOT IN (%s)", e.column, sql) } default: if isListType(v) { if reflect.ValueOf(v).Len() == 0 { return "", nil, nil } if reflect.ValueOf(v).Len() == 1 { args = []any{reflect.ValueOf(v).Index(0).Interface()} sql = fmt.Sprintf("%s<>?", e.column) } else { args = []any{v} sql = fmt.Sprintf("%s<>ALL(?)", e.column) } } else { args = []any{v} sql = fmt.Sprintf("%s<>?", e.column) } } return sql, args, err } // rangeExpr helps to use BETWEEN in SQL query type rangeExpr struct { column string start any end any } // Range allows to use range in SQL query // Ex: SelectBuilder.Where(Range("id", 1, 3)) -> "id BETWEEN 1 AND 3" // If start or end is nil, it will be omitted from the query. // Ex: SelectBuilder.Where(Range("id", 1, nil)) -> "id >= 1" // Ex: SelectBuilder.Where(Range("id", nil, 3)) -> "id <= 3" func Range(column string, start, end any) rangeExpr { return rangeExpr{column, start, end} } // ToSql builds the query into a SQL string and bound args. func (e rangeExpr) ToSql() (sql string, args []any, err error) { hasStart := e.start != nil && !reflect.ValueOf(e.start).IsZero() hasEnd := e.end != nil && !reflect.ValueOf(e.end).IsZero() if !hasStart && !hasEnd { return "", nil, nil } var s Sqlizer if hasStart && hasEnd { s = Expr(fmt.Sprintf("%s BETWEEN ? AND ?", e.column), e.start, e.end) } else if hasStart { s = GtOrEq{e.column: e.start} } else { s = LtOrEq{e.column: e.end} } return s.ToSql() } // EqNotEmpty ignores empty and zero values in Eq map. // Ex: EqNotEmpty{"id1": 1, "name": nil, id2: 0, "desc": ""} -> "id1 = 1". type EqNotEmpty map[string]any // ToSql builds the query into a SQL string and bound args. func (eq EqNotEmpty) ToSql() (sql string, args []any, err error) { vals := make(Eq, len(eq)) for k, v := range eq { v = clearEmptyValue(v) if v != nil { vals[k] = v } } return vals.ToSql() } // clearEmptyValue recursively clears empty and zero values in any type. 
func clearEmptyValue(v any) any { if v == nil { return nil } t := reflect.ValueOf(v) switch t.Kind() { //nolint:exhaustive case reflect.Array, reflect.Slice: if t.Len() != 0 { newSlice := reflect.MakeSlice(t.Type(), 0, t.Len()) for i := 0; i < t.Len(); i++ { itemVal := clearEmptyValue(t.Index(i).Interface()) if itemVal != nil { newSlice = reflect.Append(newSlice, t.Index(i)) } } if newSlice.Len() != 0 { return newSlice.Interface() } } default: if !t.IsZero() { return v } } return nil } type cteExpr struct { expr Sqlizer cte string } // Cte allows to define CTE (Common Table Expressions) in SQL query func Cte(e Sqlizer, cte string) cteExpr { return cteExpr{e, cte} } // ToSql builds the query into a SQL string and bound args. func (e cteExpr) ToSql() (sql string, args []any, err error) { sql, args, err = e.expr.ToSql() if err == nil { sql = fmt.Sprintf("%s AS (%s)", e.cte, sql) } return } type notExpr struct { expr Sqlizer } // ToSql builds the query into a SQL string and bound args. func (e notExpr) ToSql() (sql string, args []any, err error) { sql, args, err = e.expr.ToSql() if err == nil { sql = fmt.Sprintf("NOT (%s)", sql) } return } // Not is a helper function to negate a condition. func Not(e Sqlizer) Sqlizer { // check nested NOT if n, ok := e.(notExpr); ok { return n.expr } return notExpr{e} } type coalesceExpr struct { exprs []Sqlizer null any } // Coalesce is a helper function to use COALESCE in SQL query func Coalesce(nullValue any, exprs ...Sqlizer) Sqlizer { return coalesceExpr{exprs, nullValue} } // ToSql builds the query into a SQL string and bound args. 
func (e coalesceExpr) ToSql() (sql string, args []any, err error) { exprs := make([]string, 0, len(e.exprs)) for _, expr := range e.exprs { var exprSQL string exprSQL, args, err = expr.ToSql() if err != nil { return } if exprSQL == "" { continue } exprs = append(exprs, fmt.Sprintf("(%s)", exprSQL)) } if len(exprs) == 0 { return "", nil, nil } sql = fmt.Sprintf("COALESCE(%s, ?)", strings.Join(exprs, ", ")) args = append(args, e.null) return }
package squirrel import ( "bytes" "errors" "fmt" "io" "sort" "strings" "github.com/lann/builder" ) type insertData struct { PlaceholderFormat PlaceholderFormat Prefixes []Sqlizer StatementKeyword string Options []string Into string Columns []string Values [][]any Suffixes []Sqlizer Select *SelectBuilder } func (d *insertData) ToSql() (sqlStr string, args []any, err error) { if len(d.Into) == 0 { err = errors.New("insert statements must specify a table") return "", nil, err } if len(d.Values) == 0 && d.Select == nil { err = errors.New("insert statements must have at least one set of values or select clause") return "", nil, err } sql := &bytes.Buffer{} if len(d.Prefixes) > 0 { args, err = appendToSql(d.Prefixes, sql, " ", args) if err != nil { return "", nil, err } sql.WriteString(" ") } if d.StatementKeyword == "" { _, _ = sql.WriteString("INSERT ") } else { _, _ = sql.WriteString(d.StatementKeyword) _, _ = sql.WriteString(" ") } if len(d.Options) > 0 { _, _ = sql.WriteString(strings.Join(d.Options, " ")) _, _ = sql.WriteString(" ") } _, _ = sql.WriteString("INTO ") _, _ = sql.WriteString(d.Into) _, _ = sql.WriteString(" ") if len(d.Columns) > 0 { _, _ = sql.WriteString("(") _, _ = sql.WriteString(strings.Join(d.Columns, ",")) _, _ = sql.WriteString(") ") } if d.Select != nil { args, err = d.appendSelectToSQL(sql, args) } else { args, err = d.appendValuesToSQL(sql, args) } if err != nil { return "", nil, err } if len(d.Suffixes) > 0 { sql.WriteString(" ") args, err = appendToSql(d.Suffixes, sql, " ", args) if err != nil { return "", nil, err } } sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String()) return sqlStr, args, err } func (d *insertData) appendValuesToSQL(w io.Writer, args []any) ([]any, error) { if len(d.Values) == 0 { return args, errors.New("values for insert statements are not set") } _, _ = io.WriteString(w, "VALUES ") valuesStrings := make([]string, len(d.Values)) for r, row := range d.Values { valueStrings := make([]string, len(row)) 
for v, val := range row { if vs, ok := val.(Sqlizer); ok { vsql, vargs, err := vs.ToSql() if err != nil { return nil, err } valueStrings[v] = vsql args = append(args, vargs...) } else { valueStrings[v] = "?" args = append(args, val) } } valuesStrings[r] = fmt.Sprintf("(%s)", strings.Join(valueStrings, ",")) } _, _ = io.WriteString(w, strings.Join(valuesStrings, ",")) return args, nil } func (d *insertData) appendSelectToSQL(w io.Writer, args []any) ([]any, error) { if d.Select == nil { return args, errors.New("select clause for insert statements are not set") } selectClause, sArgs, err := d.Select.ToSql() if err != nil { return args, err } _, _ = io.WriteString(w, selectClause) args = append(args, sArgs...) return args, nil } // Builder // InsertBuilder builds SQL INSERT statements. type InsertBuilder builder.Builder func init() { builder.Register(InsertBuilder{}, insertData{}) } // Format methods // PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the // query. func (b InsertBuilder) PlaceholderFormat(f PlaceholderFormat) InsertBuilder { return builder.Set(b, "PlaceholderFormat", f).(InsertBuilder) } // SQL methods // ToSql builds the query into a SQL string and bound args. func (b InsertBuilder) ToSql() (string, []any, error) { data := builder.GetStruct(b).(insertData) return data.ToSql() } // MustSql builds the query into a SQL string and bound args. // It panics if there are any errors. func (b InsertBuilder) MustSql() (string, []any) { sql, args, err := b.ToSql() if err != nil { panic(err) } return sql, args } // Prefix adds an expression to the beginning of the query func (b InsertBuilder) Prefix(sql string, args ...any) InsertBuilder { return b.PrefixExpr(Expr(sql, args...)) } // PrefixExpr adds an expression to the very beginning of the query func (b InsertBuilder) PrefixExpr(e Sqlizer) InsertBuilder { return builder.Append(b, "Prefixes", e).(InsertBuilder) } // Options adds keyword options before the INTO clause of the query. 
func (b InsertBuilder) Options(options ...string) InsertBuilder {
	next := builder.Extend(b, "Options", options)
	return next.(InsertBuilder)
}

// Into sets the table name that follows the INTO keyword.
func (b InsertBuilder) Into(from string) InsertBuilder {
	next := builder.Set(b, "Into", from)
	return next.(InsertBuilder)
}

// Columns appends names to the column list of the query.
func (b InsertBuilder) Columns(columns ...string) InsertBuilder {
	next := builder.Extend(b, "Columns", columns)
	return next.(InsertBuilder)
}

// Values appends one row of values to the query. A value that implements
// Sqlizer is rendered as SQL; any other value is bound to a placeholder.
func (b InsertBuilder) Values(values ...any) InsertBuilder {
	next := builder.Append(b, "Values", values)
	return next.(InsertBuilder)
}

// Suffix appends an SQL fragment, with optional bound args, to the end of the
// query.
func (b InsertBuilder) Suffix(sql string, args ...any) InsertBuilder {
	return b.SuffixExpr(Expr(sql, args...))
}

// SuffixExpr appends an expression to the end of the query.
func (b InsertBuilder) SuffixExpr(e Sqlizer) InsertBuilder {
	next := builder.Append(b, "Suffixes", e)
	return next.(InsertBuilder)
}

// SetMap replaces the columns and values of the query with the entries of
// clauses, producing a single row. Any columns or values set earlier are
// discarded. Column names are sorted so that the generated SQL does not
// depend on map iteration order.
func (b InsertBuilder) SetMap(clauses map[string]any) InsertBuilder {
	columns := make([]string, 0, len(clauses))
	for name := range clauses {
		columns = append(columns, name)
	}
	sort.Strings(columns)

	row := make([]any, len(columns))
	for idx, name := range columns {
		row[idx] = clauses[name]
	}

	withColumns := builder.Set(b, "Columns", columns).(InsertBuilder)
	return builder.Set(withColumns, "Values", [][]any{row}).(InsertBuilder)
}

// Select uses sb as the source of the inserted rows (INSERT ... SELECT).
// When both Values and Select are set, Select takes precedence.
func (b InsertBuilder) Select(sb SelectBuilder) InsertBuilder {
	next := builder.Set(b, "Select", &sb)
	return next.(InsertBuilder)
}

// statementKeyword replaces the leading INSERT keyword, e.g. with "REPLACE".
func (b InsertBuilder) statementKeyword(keyword string) InsertBuilder {
	next := builder.Set(b, "StatementKeyword", keyword)
	return next.(InsertBuilder)
}
package squirrel import ( "fmt" "io" ) type part struct { pred any args []any } func newPart(pred any, args ...any) Sqlizer { return &part{pred, args} } func (p part) ToSql() (sql string, args []any, err error) { switch pred := p.pred.(type) { case nil: // no-op case Sqlizer: sql, args, err = nestedToSql(pred) case string: sql = pred args = p.args default: err = fmt.Errorf("expected string or Sqlizer, not %T", pred) } return } func nestedToSql(s Sqlizer) (string, []any, error) { if raw, ok := s.(rawSqlizer); ok { return raw.toSqlRaw() } else { return s.ToSql() } } func appendToSql(parts []Sqlizer, w io.Writer, sep string, args []any) ([]any, error) { for i, p := range parts { partSql, partArgs, err := nestedToSql(p) if err != nil { return nil, err } else if len(partSql) == 0 { continue } if i > 0 { _, err = io.WriteString(w, sep) if err != nil { return nil, err } } _, err = io.WriteString(w, partSql) if err != nil { return nil, err } args = append(args, partArgs...) } return args, nil }
package squirrel import ( "bytes" "fmt" "strings" ) // PlaceholderFormat is the interface that wraps the ReplacePlaceholders method. // // ReplacePlaceholders takes a SQL statement and replaces each question mark // placeholder with a (possibly different) SQL placeholder. type PlaceholderFormat interface { ReplacePlaceholders(sql string) (string, error) } type placeholderDebugger interface { debugPlaceholder() string } var ( // Question is a PlaceholderFormat instance that leaves placeholders as // question marks. Question = questionFormat{} // Dollar is a PlaceholderFormat instance that replaces placeholders with // dollar-prefixed positional placeholders (e.g. $1, $2, $3). Dollar = dollarFormat{} // Colon is a PlaceholderFormat instance that replaces placeholders with // colon-prefixed positional placeholders (e.g. :1, :2, :3). Colon = colonFormat{} // AtP is a PlaceholderFormat instance that replaces placeholders with // "@p"-prefixed positional placeholders (e.g. @p1, @p2, @p3). AtP = atpFormat{} ) type questionFormat struct{} func (questionFormat) ReplacePlaceholders(sql string) (string, error) { return sql, nil } func (questionFormat) debugPlaceholder() string { return "?" } type dollarFormat struct{} func (dollarFormat) ReplacePlaceholders(sql string) (string, error) { return replacePositionalPlaceholders(sql, "$") } func (dollarFormat) debugPlaceholder() string { return "$" } type colonFormat struct{} func (colonFormat) ReplacePlaceholders(sql string) (string, error) { return replacePositionalPlaceholders(sql, ":") } func (colonFormat) debugPlaceholder() string { return ":" } type atpFormat struct{} func (atpFormat) ReplacePlaceholders(sql string) (string, error) { return replacePositionalPlaceholders(sql, "@p") } func (atpFormat) debugPlaceholder() string { return "@p" } // Placeholders returns a string with count ? placeholders joined with commas. 
//
// For example, Placeholders(3) returns "?,?,?". A count below 1 yields "".
func Placeholders(count int) string {
	if count < 1 {
		return ""
	}

	return strings.Repeat(",?", count)[1:]
}

// replacePositionalPlaceholders rewrites every "?" placeholder in sql as
// prefix followed by its 1-based position (e.g. "$1", "$2"). A doubled "??"
// is an escape for a literal "?": it is emitted as a single "?" and does not
// use up a position. The returned error is always nil.
func replacePositionalPlaceholders(sql, prefix string) (string, error) {
	buf := &bytes.Buffer{}
	buf.Grow(len(sql))

	n := 0
	for {
		p := strings.IndexByte(sql, '?')
		if p == -1 {
			break
		}

		buf.WriteString(sql[:p])
		if strings.HasPrefix(sql[p:], "??") { // escape ?? => ?
			buf.WriteByte('?')
			sql = sql[p+2:]
			continue
		}

		n++
		fmt.Fprintf(buf, "%s%d", prefix, n)
		sql = sql[p+1:]
	}

	buf.WriteString(sql)
	return buf.String(), nil
}
package squirrel import ( "bytes" "fmt" "strings" "golang.org/x/exp/slices" "github.com/lann/builder" ) // Direction is used in OrderByDir to specify the direction of the ordering. type Direction int const ( Asc Direction = iota Desc ) type PaginatorType int const ( PaginatorTypeUndefined PaginatorType = iota PaginatorTypeByPage PaginatorTypeByID ) // Paginator is a helper object to paginate results. type Paginator struct { limit uint64 page uint64 lastID int64 pType PaginatorType } // PaginatorByPage creates a new Paginator for pagination by page. func PaginatorByPage(pageSize, pageNum uint64) Paginator { return Paginator{ limit: pageSize, page: pageNum, pType: PaginatorTypeByPage, } } // PaginatorByID creates a new Paginator for pagination by ID. func PaginatorByID(limit uint64, lastID int64) Paginator { return Paginator{ limit: limit, lastID: lastID, pType: PaginatorTypeByID, } } // PageSize returns the page size for PaginatorTypeByPage func (p Paginator) PageSize() uint64 { return p.limit } // PageNumber returns the page number for PaginatorTypeByPage func (p Paginator) PageNumber() uint64 { return p.page } // Limit returns the limit for PaginatorTypeByID func (p Paginator) Limit() uint64 { return p.limit } // LastID returns the last ID for PaginatorTypeByID func (p Paginator) LastID() int64 { return p.lastID } // Type returns the type of the paginator. func (p Paginator) Type() PaginatorType { return p.pType } // String returns the string representation of the direction. func (d Direction) String() string { if d == Asc { return "ASC" } return "DESC" } // OrderCond is used in OrderByDir to specify the condition of the ordering. 
type OrderCond struct { ColumnID int Direction Direction } type selectData struct { PlaceholderFormat PlaceholderFormat Prefixes []Sqlizer Options []string Columns []Sqlizer From Sqlizer Joins []Sqlizer WhereParts []Sqlizer GroupBys []string HavingParts []Sqlizer OrderByParts []Sqlizer Limit string Offset string Suffixes []Sqlizer Paginator Paginator IDColumn string // ID column name. Required for pagination by ID. } func (d *selectData) ToSql() (sqlStr string, args []any, err error) { sqlStr, args, err = d.toSqlRaw() if err != nil { return } sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sqlStr) return } func (d *selectData) toSqlRaw() (sqlStr string, args []any, err error) { if len(d.Columns) == 0 { err = fmt.Errorf("select statements must have at least one result column") return "", nil, err } sql := &bytes.Buffer{} if len(d.Prefixes) > 0 { args, err = appendToSql(d.Prefixes, sql, " ", args) if err != nil { return "", nil, err } _, _ = sql.WriteString(" ") } _, _ = sql.WriteString("SELECT ") if len(d.Options) > 0 { _, _ = sql.WriteString(strings.Join(d.Options, " ")) _, _ = sql.WriteString(" ") } if len(d.Columns) > 0 { args, err = appendToSql(d.Columns, sql, ", ", args) if err != nil { return "", nil, err } } if d.From != nil { _, _ = sql.WriteString(" FROM ") args, err = appendToSql([]Sqlizer{d.From}, sql, "", args) if err != nil { return "", nil, err } } if len(d.Joins) > 0 { _, _ = sql.WriteString(" ") args, err = appendToSql(d.Joins, sql, " ", args) if err != nil { return "", nil, err } } whereParts := make([]Sqlizer, len(d.WhereParts)) copy(whereParts, d.WhereParts) if d.Paginator.pType == PaginatorTypeByID { if d.IDColumn == "" { return "", nil, fmt.Errorf("IDColumn is required for pagination by ID") } whereParts = append(whereParts, Gt{d.IDColumn: d.Paginator.lastID}) } if len(whereParts) > 0 { _, _ = sql.WriteString(" WHERE ") args, err = appendToSql(whereParts, sql, " AND ", args) if err != nil { return "", nil, err } } if len(d.GroupBys) > 0 { 
_, _ = sql.WriteString(" GROUP BY ") _, _ = sql.WriteString(strings.Join(d.GroupBys, ", ")) } if len(d.HavingParts) > 0 { _, _ = sql.WriteString(" HAVING ") args, err = appendToSql(d.HavingParts, sql, " AND ", args) if err != nil { return "", nil, err } } if len(d.OrderByParts) > 0 { _, _ = sql.WriteString(" ORDER BY ") args, err = appendToSql(d.OrderByParts, sql, ", ", args) if err != nil { return "", nil, err } } if len(d.Limit) > 0 { if d.Paginator.pType != PaginatorTypeUndefined { return "", nil, fmt.Errorf("limit and paginator cannot be used together") } _, _ = sql.WriteString(" LIMIT ") _, _ = sql.WriteString(d.Limit) } if len(d.Offset) > 0 { if d.Paginator.pType != PaginatorTypeUndefined { return "", nil, fmt.Errorf("offset and paginator cannot be used together") } _, _ = sql.WriteString(" OFFSET ") _, _ = sql.WriteString(d.Offset) } if d.Paginator.pType == PaginatorTypeByPage { _, _ = sql.WriteString(fmt.Sprintf(" LIMIT %d", d.Paginator.limit)) if d.Paginator.page > 1 { _, _ = sql.WriteString(fmt.Sprintf(" OFFSET %d", d.Paginator.limit*(d.Paginator.page-1))) } } else if d.Paginator.pType == PaginatorTypeByID { _, _ = sql.WriteString(fmt.Sprintf(" LIMIT %d", d.Paginator.limit)) } if len(d.Suffixes) > 0 { _, _ = sql.WriteString(" ") args, err = appendToSql(d.Suffixes, sql, " ", args) if err != nil { return "", nil, err } } sqlStr = sql.String() return sqlStr, args, nil } // Builder // SelectBuilder builds SQL SELECT statements. type SelectBuilder builder.Builder func init() { builder.Register(SelectBuilder{}, selectData{}) } // Format methods // PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the // query. func (b SelectBuilder) PlaceholderFormat(f PlaceholderFormat) SelectBuilder { return builder.Set(b, "PlaceholderFormat", f).(SelectBuilder) } // SQL methods // ToSql builds the query into a SQL string and bound args. 
func (b SelectBuilder) ToSql() (string, []any, error) { data := builder.GetStruct(b).(selectData) return data.ToSql() } func (b SelectBuilder) toSqlRaw() (string, []any, error) { data := builder.GetStruct(b).(selectData) return data.toSqlRaw() } // MustSql builds the query into a SQL string and bound args. // It panics if there are any errors. func (b SelectBuilder) MustSql() (string, []any) { sql, args, err := b.ToSql() if err != nil { panic(err) } return sql, args } // Prefix adds an expression to the beginning of the query func (b SelectBuilder) Prefix(sql string, args ...any) SelectBuilder { return b.PrefixExpr(Expr(sql, args...)) } // PrefixExpr adds an expression to the very beginning of the query func (b SelectBuilder) PrefixExpr(e Sqlizer) SelectBuilder { return builder.Append(b, "Prefixes", e).(SelectBuilder) } // Distinct adds a DISTINCT clause to the query. func (b SelectBuilder) Distinct() SelectBuilder { return b.Options("DISTINCT") } // Options adds select option to the query func (b SelectBuilder) Options(options ...string) SelectBuilder { return builder.Extend(b, "Options", options).(SelectBuilder) } // Columns adds result columns to the query. func (b SelectBuilder) Columns(columns ...string) SelectBuilder { parts := make([]any, 0, len(columns)) for _, str := range columns { parts = append(parts, newPart(str)) } return builder.Extend(b, "Columns", parts).(SelectBuilder) } // RemoveColumns remove all columns from query. // Must add a new column with Column or Columns methods, otherwise // return a error. func (b SelectBuilder) RemoveColumns() SelectBuilder { return builder.Delete(b, "Columns").(SelectBuilder) } // Column adds a result column to the query. 
// Unlike Columns, Column accepts args which will be bound to placeholders in // the columns string, for example: // // Column("IF(col IN ("+squirrel.Placeholders(3)+"), 1, 0) as col", 1, 2, 3) func (b SelectBuilder) Column(column any, args ...any) SelectBuilder { return builder.Append(b, "Columns", newPart(column, args...)).(SelectBuilder) } // From sets the FROM clause of the query. func (b SelectBuilder) From(from string) SelectBuilder { return builder.Set(b, "From", newPart(from)).(SelectBuilder) } // FromSelect sets a subquery into the FROM clause of the query. func (b SelectBuilder) FromSelect(from SelectBuilder, alias string) SelectBuilder { // Prevent misnumbered parameters in nested selects (#183). from = from.PlaceholderFormat(Question) return builder.Set(b, "From", Alias(from, alias)).(SelectBuilder) } // JoinClause adds a join clause to the query. func (b SelectBuilder) JoinClause(pred any, args ...any) SelectBuilder { return builder.Append(b, "Joins", newPart(pred, args...)).(SelectBuilder) } // Join adds a JOIN clause to the query. func (b SelectBuilder) Join(join string, rest ...any) SelectBuilder { return b.JoinClause("JOIN "+join, rest...) } // LeftJoin adds a LEFT JOIN clause to the query. func (b SelectBuilder) LeftJoin(join string, rest ...any) SelectBuilder { return b.JoinClause("LEFT JOIN "+join, rest...) } // RightJoin adds a RIGHT JOIN clause to the query. func (b SelectBuilder) RightJoin(join string, rest ...any) SelectBuilder { return b.JoinClause("RIGHT JOIN "+join, rest...) } // InnerJoin adds a INNER JOIN clause to the query. func (b SelectBuilder) InnerJoin(join string, rest ...any) SelectBuilder { return b.JoinClause("INNER JOIN "+join, rest...) } // CrossJoin adds a CROSS JOIN clause to the query. func (b SelectBuilder) CrossJoin(join string, rest ...any) SelectBuilder { return b.JoinClause("CROSS JOIN "+join, rest...) } // Where adds an expression to the WHERE clause of the query. 
// // Expressions are ANDed together in the generated SQL. // // Where accepts several types for its pred argument: // // nil OR "" - ignored. // // string - SQL expression. // If the expression has SQL placeholders then a set of arguments must be passed // as well, one for each placeholder. // // map[string]any OR Eq - map of SQL expressions to values. Each key is // transformed into an expression like "<key> = ?", with the corresponding value // bound to the placeholder. If the value is nil, the expression will be "<key> // IS NULL". If the value is an array or slice, the expression will be "<key> IN // (?,?,...)", with one placeholder for each item in the value. These expressions // are ANDed together. // // Where will panic if pred isn't any of the above types. func (b SelectBuilder) Where(pred any, args ...any) SelectBuilder { if pred == nil || pred == "" { return b } return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(SelectBuilder) } // GroupBy adds GROUP BY expressions to the query. func (b SelectBuilder) GroupBy(groupBys ...string) SelectBuilder { return builder.Extend(b, "GroupBys", groupBys).(SelectBuilder) } // Having adds an expression to the HAVING clause of the query. // // See Where. func (b SelectBuilder) Having(pred any, rest ...any) SelectBuilder { return builder.Append(b, "HavingParts", newWherePart(pred, rest...)).(SelectBuilder) } // OrderByClause adds ORDER BY clause to the query. func (b SelectBuilder) OrderByClause(pred any, args ...any) SelectBuilder { return builder.Append(b, "OrderByParts", newPart(pred, args...)).(SelectBuilder) } // OrderBy adds ORDER BY expressions to the query. func (b SelectBuilder) OrderBy(orderBys ...string) SelectBuilder { for _, orderBy := range orderBys { b = b.OrderByClause(orderBy) } return b } // OrderNullsType is used to specify the order of NULLs in ORDER BY clause. type OrderNullsType int const ( OrderNullsUndefined OrderNullsType = iota OrderNullsFirst // ORDER BY ... 
NULLS FIRST OrderNullsLast // ORDER BY ... NULLS LAST ) // String returns the string representation of the order of NULLs. func (o OrderNullsType) String() string { if o == OrderNullsFirst { return "FIRST" } if o == OrderNullsLast { return "LAST" } return "" } // OrderByCondOption is used to specify additional options for OrderByCond. type OrderByCondOption struct { ColumnID int NullsType OrderNullsType } // OrderByCond adds ORDER BY expressions with direction to the query. // The columns map is used to map OrderCond.ColumnID to the column name. // Can be used to avoid hardcoding column names in the code. func (b SelectBuilder) OrderByCond(columns map[int]string, conds []OrderCond, opts ...OrderByCondOption) SelectBuilder { for i, cond := range conds { if pos := slices.IndexFunc(conds[:i], func(c OrderCond) bool { return c.ColumnID == cond.ColumnID }); pos >= 0 && pos < i { continue } column, ok := columns[cond.ColumnID] if !ok { panic(fmt.Sprintf("column id %d not found in columns map %v", cond.ColumnID, columns)) } nullsType := OrderNullsUndefined for _, opt := range opts { if opt.ColumnID == cond.ColumnID { nullsType = opt.NullsType break } } if nullsType == OrderNullsUndefined { b = b.OrderByClause(fmt.Sprintf("%s %s", column, cond.Direction.String())) } else { b = b.OrderByClause(fmt.Sprintf("%s %s NULLS %s", column, cond.Direction.String(), nullsType.String())) } } return b } // Search adds a search condition to the query. // The search condition is a WHERE clause with LIKE expressions. All columns will be converted to text. // value can be a string or a number. func (b SelectBuilder) Search(value any, columns ...string) SelectBuilder { if len(columns) == 0 { return b } search := Or{} for _, column := range columns { search = append(search, Like{column + "::text": fmt.Sprintf("%%%v%%", value)}) } return b.Where(search) } // PaginateByID adds a LIMIT and start from ID condition to the query. 
// WARNING: The columnID must be included in the ORDER BY clause to avoid unexpected results! func (b SelectBuilder) PaginateByID(limit uint64, startID int64, columnID string) SelectBuilder { return b.Limit(limit).Where(Gt{columnID: startID}) } // PaginateByPage adds a LIMIT and OFFSET condition to the query. // WARNING: query must be ordered to avoid unexpected results! func (b SelectBuilder) PaginateByPage(limit uint64, page uint64) SelectBuilder { sb := b.Limit(limit) if page > 1 { sb = sb.Offset(limit * (page - 1)) } return sb } // Paginate adds pagination conditions to the query. func (b SelectBuilder) Paginate(p Paginator) SelectBuilder { return builder.Set(b, "Paginator", p).(SelectBuilder) } // SetIDColumn sets the column name to be used for pagination by ID. // Required in special cases when Paginate function combined with PaginatorByID. func (b SelectBuilder) SetIDColumn(column string) SelectBuilder { return builder.Set(b, "IDColumn", column).(SelectBuilder) } // Limit sets a LIMIT clause on the query. func (b SelectBuilder) Limit(limit uint64) SelectBuilder { return builder.Set(b, "Limit", fmt.Sprintf("%d", limit)).(SelectBuilder) } // RemoveLimit Limit ALL allows to access all records with limit func (b SelectBuilder) RemoveLimit() SelectBuilder { return builder.Delete(b, "Limit").(SelectBuilder) } // Offset sets a OFFSET clause on the query. func (b SelectBuilder) Offset(offset uint64) SelectBuilder { return builder.Set(b, "Offset", fmt.Sprintf("%d", offset)).(SelectBuilder) } // RemoveOffset removes OFFSET clause. 
func (b SelectBuilder) RemoveOffset() SelectBuilder { return builder.Delete(b, "Offset").(SelectBuilder) } // Suffix adds an expression to the end of the query func (b SelectBuilder) Suffix(sql string, args ...any) SelectBuilder { return b.SuffixExpr(Expr(sql, args...)) } // SuffixExpr adds an expression to the end of the query func (b SelectBuilder) SuffixExpr(e Sqlizer) SelectBuilder { return builder.Append(b, "Suffixes", e).(SelectBuilder) } type alias struct { builder SelectBuilder table string prefix []string } // Columns sets the columns for the table alias. func (a alias) Columns(columns ...string) SelectBuilder { if len(columns) == 0 { return a.builder } return a.builder.Columns(prepareAliasColumns(a.table, a.prefix, columns...)...) } // GroupBy sets the group by for the table alias. func (a alias) GroupBy(groupBys ...string) SelectBuilder { if len(groupBys) == 0 { return a.builder } return a.builder.GroupBy(prepareAliasColumns(a.table, a.prefix, groupBys...)...) } // OrderBy sets the order by for the table alias. func (a alias) OrderBy(orderBys ...string) SelectBuilder { if len(orderBys) == 0 { return a.builder } return a.builder.OrderBy(prepareAliasColumns(a.table, a.prefix, orderBys...)...) } func prepareAliasColumns(table string, prefix []string, columns ...string) []string { columnsPrepared := make([]string, 0, len(columns)) for _, column := range columns { if len(prefix) == 0 { if table == "" { columnsPrepared = append(columnsPrepared, column) } else { columnsPrepared = append(columnsPrepared, fmt.Sprintf("%s.%s", table, column)) } } else { if table == "" { columnsPrepared = append(columnsPrepared, fmt.Sprintf("%s AS %s_%s", column, prefix[0], column)) } else { columnsPrepared = append(columnsPrepared, fmt.Sprintf("%s.%s AS %s_%s", table, column, prefix[0], column)) } } } return columnsPrepared } // Alias creates a new table alias for the select builder. // Prefix is used to add a prefix to the beginning of the column names. 
If no prefix, the column name will be used. // All prefixes except the first will be ignored. func (b SelectBuilder) Alias(table string, prefix ...string) alias { return alias{ builder: b, table: table, prefix: prefix, } } // With adds a CTE (Common Table Expression) to the query. func (b SelectBuilder) With(cteName string, cte SelectBuilder) SelectBuilder { return b.PrefixExpr(cte.Prefix(fmt.Sprintf("WITH %s AS (", cteName)).Suffix(")")) }
// Package squirrel provides a fluent SQL generator. // // See https://github.com/Masterminds/squirrel for examples. package squirrel import ( "bytes" "fmt" "strings" ) // Sqlizer is the interface that wraps the ToSql method. // // ToSql returns a SQL representation of the Sqlizer, along with a slice of args // as passed to e.g. database/sql.Exec. It can also return an error. type Sqlizer interface { ToSql() (string, []any, error) } // rawSqlizer is expected to do what Sqlizer does, but without finalizing placeholders. // This is useful for nested queries. type rawSqlizer interface { toSqlRaw() (string, []any, error) } // DebugSqlizer calls ToSql on s and shows the approximate SQL to be executed // // If ToSql returns an error, the result of this method will look like: // "[ToSql error: %s]" or "[DebugSqlizer error: %s]" // // IMPORTANT: As its name suggests, this function should only be used for // debugging. While the string result *might* be valid SQL, this function does // not try very hard to ensure it. Additionally, executing the output of this // function with any untrusted user input is certainly insecure. func DebugSqlizer(s Sqlizer) string { sql, args, err := s.ToSql() if err != nil { return fmt.Sprintf("[ToSql error: %s]", err) } var placeholder string downCast, ok := s.(placeholderDebugger) if !ok { placeholder = "?" } else { placeholder = downCast.debugPlaceholder() } // TODO: dedupe this with placeholder.go buf := &bytes.Buffer{} i := 0 for { p := strings.Index(sql, placeholder) if p == -1 { break } if len(sql[p:]) > 1 && sql[p:p+2] == "??" { // escape ?? => ? 
buf.WriteString(sql[:p]) buf.WriteString("?") if len(sql[p:]) == 1 { break } sql = sql[p+2:] } else { if i+1 > len(args) { return fmt.Sprintf( "[DebugSqlizer error: too many placeholders in %#v for %d args]", sql, len(args)) } buf.WriteString(sql[:p]) fmt.Fprintf(buf, "'%v'", args[i]) // advance our sql string "cursor" beyond the arg we placed sql = sql[p+1:] i++ } } if i < len(args) { return fmt.Sprintf( "[DebugSqlizer error: not enough placeholders in %#v for %d args]", sql, len(args)) } // "append" any remaning sql that won't need interpolating buf.WriteString(sql) return buf.String() }
package squirrel import "github.com/lann/builder" // StatementBuilderType is the type of StatementBuilder. type StatementBuilderType builder.Builder // Select returns a SelectBuilder for this StatementBuilderType. func (b StatementBuilderType) Select(columns ...string) SelectBuilder { return SelectBuilder(b).Columns(columns...) } // Insert returns a InsertBuilder for this StatementBuilderType. func (b StatementBuilderType) Insert(into string) InsertBuilder { return InsertBuilder(b).Into(into) } // Replace returns a InsertBuilder for this StatementBuilderType with the // statement keyword set to "REPLACE". func (b StatementBuilderType) Replace(into string) InsertBuilder { return InsertBuilder(b).statementKeyword("REPLACE").Into(into) } // Update returns a UpdateBuilder for this StatementBuilderType. func (b StatementBuilderType) Update(table string) UpdateBuilder { return UpdateBuilder(b).Table(table) } // Delete returns a DeleteBuilder for this StatementBuilderType. func (b StatementBuilderType) Delete(from string) DeleteBuilder { return DeleteBuilder(b).From(from) } // With returns a CommonTableExpressionsBuilder for this StatementBuilderType func (b StatementBuilderType) With(cte string) CommonTableExpressionsBuilder { return CommonTableExpressionsBuilder(b).Cte(cte) } // PlaceholderFormat sets the PlaceholderFormat field for any child builders. func (b StatementBuilderType) PlaceholderFormat(f PlaceholderFormat) StatementBuilderType { return builder.Set(b, "PlaceholderFormat", f).(StatementBuilderType) } // Where adds WHERE expressions to the query. // // See SelectBuilder.Where for more information. func (b StatementBuilderType) Where(pred any, args ...any) StatementBuilderType { return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(StatementBuilderType) } // StatementBuilder is a parent builder for other builders, e.g. SelectBuilder. 
var StatementBuilder = StatementBuilderType(builder.EmptyBuilder).PlaceholderFormat(Question)

// Select returns a new SelectBuilder, optionally setting some result columns.
//
// See SelectBuilder.Columns.
func Select(columns ...string) SelectBuilder {
	return StatementBuilder.Select(columns...)
}

// Insert returns a new InsertBuilder with the given table name.
//
// See InsertBuilder.Into.
func Insert(into string) InsertBuilder {
	return StatementBuilder.Insert(into)
}

// Replace returns a new InsertBuilder with the statement keyword set to
// "REPLACE" and with the given table name.
//
// See InsertBuilder.Into.
func Replace(into string) InsertBuilder {
	return StatementBuilder.Replace(into)
}

// Update returns a new UpdateBuilder with the given table name.
//
// See UpdateBuilder.Table.
func Update(table string) UpdateBuilder {
	return StatementBuilder.Update(table)
}

// Delete returns a new DeleteBuilder with the given table name.
//
// See DeleteBuilder.From.
func Delete(from string) DeleteBuilder {
	return StatementBuilder.Delete(from)
}

// With returns a new CommonTableExpressionsBuilder with the given first cte name
//
// See CommonTableExpressionsBuilder.Cte
func With(cte string) CommonTableExpressionsBuilder {
	return StatementBuilder.With(cte)
}

// WithRecursive returns a new CommonTableExpressionsBuilder with the RECURSIVE option and the given first cte name
//
// See CommonTableExpressionsBuilder.Cte, CommonTableExpressionsBuilder.Recursive
func WithRecursive(cte string) CommonTableExpressionsBuilder {
	return StatementBuilder.With(cte).Recursive(true)
}

// Case returns a new CaseBuilder.
// "what" represents the optional case value (operand):
//   - no argument: no operand is set, giving "CASE WHEN ...";
//   - one argument: it is used as the operand;
//   - several arguments: the first is an SQL fragment (string or Sqlizer) and
//     the rest are the args bound to its placeholders.
func Case(what ...any) CaseBuilder {
	b := CaseBuilder(builder.EmptyBuilder)

	switch len(what) {
	case 0:
		// No operand.
	case 1:
		b = b.what(what[0])
	default:
		b = b.what(newPart(what[0], what[1:]...))
	}

	return b
}
package squirrel

import (
	"bytes"
	"fmt"
	"sort"
	"strings"

	"github.com/lann/builder"
)

// updateData holds the state accumulated by an UpdateBuilder.
type updateData struct {
	PlaceholderFormat PlaceholderFormat
	Prefixes          []Sqlizer
	Table             string
	SetClauses        []setClause
	From              Sqlizer
	WhereParts        []Sqlizer
	OrderBys          []string
	Limit             string
	Offset            string
	Suffixes          []Sqlizer
}

// setClause is one "column = value" assignment of the SET list.
type setClause struct {
	column string
	value  any
}

// ToSql renders the UPDATE statement in clause order: prefixes, UPDATE table,
// SET assignments, FROM, WHERE (parts joined with AND), ORDER BY, LIMIT,
// OFFSET and suffixes.
//
// A SET value that is a Sqlizer is rendered inline; a SelectBuilder value is
// additionally wrapped in parentheses as a subquery. Any other value is bound
// through a "?" placeholder. The assembled SQL is passed once through
// d.PlaceholderFormat.ReplacePlaceholders.
//
// It returns an error when no table or no SET clause was given, or when any
// nested Sqlizer fails to render.
func (d *updateData) ToSql() (sqlStr string, args []any, err error) {
	if len(d.Table) == 0 {
		err = fmt.Errorf("update statements must specify a table")
		return "", nil, err
	}
	if len(d.SetClauses) == 0 {
		err = fmt.Errorf("update statements must have at least one Set clause")
		return "", nil, err
	}

	sql := &bytes.Buffer{}

	if len(d.Prefixes) > 0 {
		args, err = appendToSql(d.Prefixes, sql, " ", args)
		if err != nil {
			return "", nil, err
		}

		_, _ = sql.WriteString(" ")
	}

	_, _ = sql.WriteString("UPDATE ")
	_, _ = sql.WriteString(d.Table)

	_, _ = sql.WriteString(" SET ")
	setSqls := make([]string, len(d.SetClauses))
	for i, setClause := range d.SetClauses {
		var valSql string
		if vs, ok := setClause.value.(Sqlizer); ok {
			var (
				vsql  string
				vargs []any
			)
			// Render through nestedToSql (as the CASE builder and WHERE parts
			// do for nested Sqlizers) rather than vs.ToSql(). Otherwise a
			// subquery built with e.g. Dollar would number its own
			// placeholders, and the final ReplacePlaceholders below would
			// restart numbering for the outer "?" marks, yielding duplicate
			// $1. NOTE(review): assumes SelectBuilder implements rawSqlizer,
			// as the upstream library does — confirm.
			vsql, vargs, err = nestedToSql(vs)
			if err != nil {
				return "", nil, err
			}

			if _, ok := vs.(SelectBuilder); ok {
				valSql = fmt.Sprintf("(%s)", vsql)
			} else {
				valSql = vsql
			}
			args = append(args, vargs...)
		} else {
			valSql = "?"
			args = append(args, setClause.value)
		}
		setSqls[i] = fmt.Sprintf("%s = %s", setClause.column, valSql)
	}
	_, _ = sql.WriteString(strings.Join(setSqls, ", "))

	if d.From != nil {
		_, _ = sql.WriteString(" FROM ")
		args, err = appendToSql([]Sqlizer{d.From}, sql, "", args)
		if err != nil {
			return "", nil, err
		}
	}

	if len(d.WhereParts) > 0 {
		_, _ = sql.WriteString(" WHERE ")
		args, err = appendToSql(d.WhereParts, sql, " AND ", args)
		if err != nil {
			return "", nil, err
		}
	}

	if len(d.OrderBys) > 0 {
		_, _ = sql.WriteString(" ORDER BY ")
		_, _ = sql.WriteString(strings.Join(d.OrderBys, ", "))
	}

	if len(d.Limit) > 0 {
		_, _ = sql.WriteString(" LIMIT ")
		_, _ = sql.WriteString(d.Limit)
	}

	if len(d.Offset) > 0 {
		_, _ = sql.WriteString(" OFFSET ")
		_, _ = sql.WriteString(d.Offset)
	}

	if len(d.Suffixes) > 0 {
		_, _ = sql.WriteString(" ")
		args, err = appendToSql(d.Suffixes, sql, " ", args)
		if err != nil {
			return "", nil, err
		}
	}

	sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String())
	return sqlStr, args, err
}

// Builder

// UpdateBuilder builds SQL UPDATE statements.
type UpdateBuilder builder.Builder

func init() {
	builder.Register(UpdateBuilder{}, updateData{})
}

// Format methods

// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the
// query.
func (b UpdateBuilder) PlaceholderFormat(f PlaceholderFormat) UpdateBuilder {
	return builder.Set(b, "PlaceholderFormat", f).(UpdateBuilder)
}

// SQL methods

// ToSql builds the query into a SQL string and bound args.
func (b UpdateBuilder) ToSql() (string, []any, error) {
	data := builder.GetStruct(b).(updateData)
	return data.ToSql()
}

// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b UpdateBuilder) MustSql() (string, []any) { sql, args, err := b.ToSql() if err != nil { panic(err) } return sql, args } // Prefix adds an expression to the beginning of the query func (b UpdateBuilder) Prefix(sql string, args ...any) UpdateBuilder { return b.PrefixExpr(Expr(sql, args...)) } // PrefixExpr adds an expression to the very beginning of the query func (b UpdateBuilder) PrefixExpr(e Sqlizer) UpdateBuilder { return builder.Append(b, "Prefixes", e).(UpdateBuilder) } // Table sets the table to be updated. func (b UpdateBuilder) Table(table string) UpdateBuilder { return builder.Set(b, "Table", table).(UpdateBuilder) } // Set adds SET clauses to the query. func (b UpdateBuilder) Set(column string, value any) UpdateBuilder { return builder.Append(b, "SetClauses", setClause{column: column, value: value}).(UpdateBuilder) } // SetMap is a convenience method which calls .Set for each key/value pair in clauses. func (b UpdateBuilder) SetMap(clauses map[string]any) UpdateBuilder { keys := make([]string, len(clauses)) i := 0 for key := range clauses { keys[i] = key i++ } sort.Strings(keys) for _, key := range keys { val := clauses[key] b = b.Set(key, val) } return b } // From adds FROM clause to the query // FROM is valid construct in postgresql only. func (b UpdateBuilder) From(from string) UpdateBuilder { return builder.Set(b, "From", newPart(from)).(UpdateBuilder) } // FromSelect sets a subquery into the FROM clause of the query. func (b UpdateBuilder) FromSelect(from SelectBuilder, alias string) UpdateBuilder { // Prevent misnumbered parameters in nested selects (#183). from = from.PlaceholderFormat(Question) return builder.Set(b, "From", Alias(from, alias)).(UpdateBuilder) } // Where adds WHERE expressions to the query. // // See SelectBuilder.Where for more information. 
func (b UpdateBuilder) Where(pred any, args ...any) UpdateBuilder {
	wp := newWherePart(pred, args...)
	return builder.Append(b, "WhereParts", wp).(UpdateBuilder)
}

// OrderBy adds ORDER BY expressions to the query.
func (b UpdateBuilder) OrderBy(orderBys ...string) UpdateBuilder {
	return builder.Extend(b, "OrderBys", orderBys).(UpdateBuilder)
}

// Limit sets a LIMIT clause on the query; limit is rendered in decimal.
func (b UpdateBuilder) Limit(limit uint64) UpdateBuilder {
	rendered := fmt.Sprintf("%d", limit)
	return builder.Set(b, "Limit", rendered).(UpdateBuilder)
}

// Offset sets an OFFSET clause on the query; offset is rendered in decimal.
func (b UpdateBuilder) Offset(offset uint64) UpdateBuilder {
	rendered := fmt.Sprintf("%d", offset)
	return builder.Set(b, "Offset", rendered).(UpdateBuilder)
}

// Suffix adds an expression, built from sql and args, to the end of the
// query.
func (b UpdateBuilder) Suffix(sql string, args ...any) UpdateBuilder {
	expr := Expr(sql, args...)
	return b.SuffixExpr(expr)
}

// SuffixExpr adds an expression to the end of the query.
func (b UpdateBuilder) SuffixExpr(e Sqlizer) UpdateBuilder {
	return builder.Append(b, "Suffixes", e).(UpdateBuilder)
}
package squirrel

import "fmt"

// wherePart is a single WHERE predicate with its bound args; it shares the
// field layout of part (pred and args).
type wherePart part

// newWherePart wraps pred and its args as a Sqlizer for a WHERE clause.
func newWherePart(pred any, args ...any) Sqlizer {
	return &wherePart{pred: pred, args: args}
}

// ToSql renders the predicate according to its dynamic type:
//   - nil renders as empty SQL with no args;
//   - a rawSqlizer is rendered with toSqlRaw (checked before Sqlizer);
//   - any other Sqlizer is rendered with its own ToSql;
//   - a map[string]any is treated as an Eq;
//   - a string is used verbatim with the stored args.
//
// Any other type yields an error.
func (p wherePart) ToSql() (sql string, args []any, err error) {
	switch pred := p.pred.(type) {
	case nil:
		// Nothing to render.
		return "", nil, nil
	case rawSqlizer:
		return pred.toSqlRaw()
	case Sqlizer:
		return pred.ToSql()
	case map[string]any:
		return Eq(pred).ToSql()
	case string:
		return pred, p.args, nil
	default:
		return "", nil, fmt.Errorf("expected string-keyed map or string, not %T", pred)
	}
}