// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package mql
import (
"fmt"
"reflect"
)
// isNil reports whether a is nil. That is the case for an untyped nil
// interface value, and for a non-nil interface that holds a nil pointer, map,
// channel, slice or func. Values of any other kind are never nil.
func isNil(a any) bool {
	if a == nil {
		return true
	}
	rv := reflect.ValueOf(a)
	switch rv.Kind() {
	case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Slice, reflect.Func:
		return rv.IsNil()
	default:
		return false
	}
}
// panicIfNil panics with the message "<caller>: missing <missing>" when a is
// nil according to isNil; otherwise it does nothing.
func panicIfNil(a any, caller, missing string) {
	if !isNil(a) {
		return
	}
	msg := fmt.Sprintf("%s: missing %s", caller, missing)
	panic(msg)
}
// pointer returns a pointer to a fresh copy of input. It is useful for taking
// the address of a literal or another non-addressable value.
func pointer[T any](input T) *T {
	p := new(T)
	*p = input
	return p
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package mql
import (
"fmt"
)
// exprType identifies which kind of node an expr is.
type exprType int

const (
	unknownExprType    exprType = iota // zero value; not returned by any Type method in this code
	comparisonExprType                 // returned by (*comparisonExpr).Type
	logicalExprType                    // returned by (*logicalExpr).Type
)

// expr is a node of the tree produced by the parser. In this code it is
// implemented by *comparisonExpr and *logicalExpr.
type expr interface {
	// Type reports which kind of node this is.
	Type() exprType
	// String returns a debugging representation of the node.
	String() string
}
// ComparisonOp defines a set of comparison operators.
//
// Each constant's value is the operator as written in a query. The default
// conversion (defaultValidateConvert) writes every operator except ContainsOp
// directly into the SQL condition. ContainsOp ("%") is instead converted to a
// SQL "like" whose argument is the value wrapped in '%' wildcards.
type ComparisonOp string

const (
	GreaterThanOp        ComparisonOp = ">"
	GreaterThanOrEqualOp ComparisonOp = ">="
	LessThanOp           ComparisonOp = "<"
	LessThanOrEqualOp    ComparisonOp = "<="
	EqualOp              ComparisonOp = "="
	NotEqualOp           ComparisonOp = "!="
	ContainsOp           ComparisonOp = "%"
)
// newComparisonOp converts s into a ComparisonOp. s must exactly match one of
// the supported operator constants; otherwise the returned error wraps
// ErrInvalidComparisonOp.
func newComparisonOp(s string) (ComparisonOp, error) {
	const op = "newComparisonOp"
	candidate := ComparisonOp(s)
	supported := []ComparisonOp{
		GreaterThanOp,
		GreaterThanOrEqualOp,
		LessThanOp,
		LessThanOrEqualOp,
		EqualOp,
		NotEqualOp,
		ContainsOp,
	}
	for _, known := range supported {
		if candidate == known {
			return candidate, nil
		}
	}
	return "", fmt.Errorf("%s: %w %q", op, ErrInvalidComparisonOp, s)
}
// comparisonExpr is a single "<column> <op> <value>" comparison. value stays
// nil until the parser has read the comparison's value.
type comparisonExpr struct {
	column       string
	comparisonOp ComparisonOp
	value        *string
}

// Type returns the expr type
func (e *comparisonExpr) Type() exprType {
	return comparisonExprType
}

// String returns a string rep of the expr; a missing value is shown as nil.
func (e *comparisonExpr) String() string {
	val := "nil"
	if e.value != nil {
		val = *e.value
	}
	return fmt.Sprintf("(comparisonExpr: %s %s %s)", e.column, e.comparisonOp, val)
}

// isComplete reports whether the column, the operator and the value have all
// been set.
func (e *comparisonExpr) isComplete() bool {
	switch {
	case e.column == "", e.comparisonOp == "", e.value == nil:
		return false
	default:
		return true
	}
}
// defaultValidateConvert validates a comparison's value and converts the
// comparison into its SQL where clause equivalent.
//
// columnName, comparisonOp, columnValue, validator.fn and validator.typ are all
// required. A missing one yields an error wrapping ErrMissingColumn,
// ErrMissingComparisonOp, ErrMissingComparisonValue or ErrInvalidParameter.
//
// validator.fn both validates and converts the value, and its result becomes
// the clause argument. A validation failure yields an error wrapping
// ErrInvalidParameter; the validator's own error text is not included.
//
// In the condition, the column is replaced by its WithTableColumnMap entry when
// one exists, and a "time" validator casts the column with "::date". ContainsOp
// produces "<column> like ?" with the value wrapped in '%' wildcards. Every
// other operator produces "<column><op>?".
// Supported options: WithTableColumnMap
func defaultValidateConvert(columnName string, comparisonOp ComparisonOp, columnValue *string, validator validator, opt ...Option) (*WhereClause, error) {
	const op = "mql.(comparisonExpr).convertToSql"
	switch {
	case columnName == "":
		return nil, fmt.Errorf("%s: %w", op, ErrMissingColumn)
	case comparisonOp == "":
		return nil, fmt.Errorf("%s: %w", op, ErrMissingComparisonOp)
	case isNil(columnValue):
		return nil, fmt.Errorf("%s: %w", op, ErrMissingComparisonValue)
	case validator.fn == nil:
		return nil, fmt.Errorf("%s: missing validator function: %w", op, ErrInvalidParameter)
	case validator.typ == "":
		return nil, fmt.Errorf("%s: missing validator type: %w", op, ErrInvalidParameter)
	}
	// everything was validated at the start, so we know this is a valid/complete comparisonExpr
	e := &comparisonExpr{
		column:       columnName,
		comparisonOp: comparisonOp,
		value:        columnValue,
	}
	v, err := validator.fn(*e.value)
	if err != nil {
		return nil, fmt.Errorf("%s: %q in %s: %w", op, *e.value, e.String(), ErrInvalidParameter)
	}
	opts, err := getOpts(opt...)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	if n, ok := opts.withTableColumnMap[columnName]; ok {
		// override our column name with the mapped column name
		columnName = n
	}
	if validator.typ == "time" {
		columnName = fmt.Sprintf("%s::date", columnName)
	}
	switch e.comparisonOp {
	case ContainsOp:
		return &WhereClause{
			Condition: fmt.Sprintf("%s like ?", columnName),
			// %v rather than %s: v is whatever validator.fn returned (an int
			// from validateInt, a float64 from validateFloat), and %s would
			// render those as "%!s(int=5)". For strings %v and %s are identical.
			// NOTE(review): '%' and '_' inside v are not escaped, so they still
			// act as LIKE wildcards.
			Args: []any{fmt.Sprintf("%%%v%%", v)},
		}, nil
	default:
		return &WhereClause{
			Condition: fmt.Sprintf("%s%s?", columnName, e.comparisonOp),
			Args:      []any{v},
		}, nil
	}
}
// logicalOp is the operator joining the two sides of a logicalExpr. The
// parser produces it from andToken/orToken via newLogicalOp; an empty
// logicalOp marks a wrapper node that only has a leftExpr.
type logicalOp string

const (
	andOp logicalOp = "and"
	orOp  logicalOp = "or"
)
// newLogicalOp converts s into a logicalOp. Only the exact strings "and" and
// "or" are accepted; anything else yields an error wrapping
// ErrInvalidLogicalOp.
func newLogicalOp(s string) (logicalOp, error) {
	const op = "newLogicalOp"
	if lo := logicalOp(s); lo == andOp || lo == orOp {
		return lo, nil
	}
	return "", fmt.Errorf("%s: %w %q", op, ErrInvalidLogicalOp, s)
}
// logicalExpr joins two sub-expressions with a logical operator. The parser
// also uses it as a wrapper node, with an empty logicalOp and only leftExpr
// set.
type logicalExpr struct {
	leftExpr  expr
	logicalOp logicalOp
	rightExpr expr
}

// Type returns the expr type
func (l *logicalExpr) Type() exprType {
	return logicalExprType
}

// String returns a string rep of the expr
func (l *logicalExpr) String() string {
	const layout = "(logicalExpr: %s %s %s)"
	return fmt.Sprintf(layout, l.leftExpr, l.logicalOp, l.rightExpr)
}
// root returns the expression that should be converted into a where clause.
//
// parseLogicalExpr may wrap expressions in logicalExprs that have no logicalOp
// and only a leftExpr. root follows that chain of leftExprs and returns either
// the first comparisonExpr or the first logicalExpr that has a logicalOp. raw
// is the original query and is only used in error messages.
//
// It returns an error wrapping:
//   - ErrInvalidParameter when lExpr is nil,
//   - ErrMissingExpr when the chain ends in a nil leftExpr,
//   - ErrMissingRightSideExpr when the logicalExpr it would return has a
//     logicalOp but no rightExpr,
//   - ErrInternal for an unexpected expr implementation.
func root(lExpr *logicalExpr, raw string) (expr, error) {
	const op = "mql.root"
	// intentionally not checking raw, since it can be an empty string
	if lExpr == nil {
		return nil, fmt.Errorf("%s: %w (missing expression)", op, ErrInvalidParameter)
	}
	for lExpr.logicalOp == "" {
		switch left := lExpr.leftExpr.(type) {
		case nil:
			return nil, fmt.Errorf("%s: %w nil in: %q", op, ErrMissingExpr, raw)
		case *comparisonExpr:
			return left, nil
		case *logicalExpr:
			lExpr = left
		default:
			return nil, fmt.Errorf("%s: unexpected expr type %T: %w", op, left, ErrInternal)
		}
	}
	// Checked after unwrapping, not only on the outermost node. Otherwise an
	// incomplete nested expression such as "(a=1 and)" would only be reported
	// later, as a generic missing expression.
	if lExpr.rightExpr == nil {
		return nil, fmt.Errorf("%s: %w in: %q", op, ErrMissingRightSideExpr, raw)
	}
	return lExpr, nil
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package mql
import (
"bufio"
"bytes"
"fmt"
"strings"
"unicode"
)
// Delimiter is a rune that may be used to quote strings in a query. A quoted
// string must be closed by the same Delimiter that opened it.
type Delimiter rune

const (
	DoubleQuote Delimiter = '"'
	SingleQuote Delimiter = '\''
	Backtick    Delimiter = '`'
	// backslash escapes a delimiter (or another backslash) inside a quoted
	// string. It has an explicit value, so it is an untyped rune constant
	// rather than a Delimiter.
	backslash = '\\'
)
// lexStateFunc is one state of the lexer's state machine. It consumes input
// from the lexer, may emit a token, and returns the next state. On error it
// returns a nil state and the error.
type lexStateFunc func(*lexer) (lexStateFunc, error)

// lexer converts a query string into tokens.
type lexer struct {
	source  *bufio.Reader // the query being scanned
	current stack[rune]   // runes read so far for the token in progress; token-producing states clear it when they return
	tokens  chan token    // tokens emitted by states and handed out by nextToken
	state   lexStateFunc  // the state nextToken runs next
}
// newLexer returns a lexer that reads s rune by rune, starting in
// lexStartState.
//
// The tokens channel has capacity 1. That is enough because nextToken receives
// any pending token before running another state, and each state emits at most
// one token.
func newLexer(s string) *lexer {
	reader := bufio.NewReader(strings.NewReader(s))
	return &lexer{
		source: reader,
		state:  lexStartState,
		tokens: make(chan token, 1),
	}
}
// nextToken is the external api for the lexer. It returns the next token,
// running lexer states until one of them emits a token. Once eof has been
// reached, every further call returns an eofToken.
//
// If a state fails, the error is returned with a zero token. The failing state
// returns a nil next state, so calling nextToken again after an error would
// panic.
func (l *lexer) nextToken() (token, error) {
	for {
		select {
		case tk := <-l.tokens:
			return tk, nil
		default:
		}
		// No token pending: advance the state machine by one step.
		next, err := l.state(l)
		l.state = next
		if err != nil {
			return token{}, err
		}
	}
}
// lexStartState is the start state. It reads one rune and chooses the next
// state from it; the only token it emits itself is eofToken, at eof. The other
// states transition back here after emitting their token.
//
// For operators, parentheses and whitespace, the rune read here stays in
// l.current for the next state. Runes that begin numbers, quoted strings or
// symbols are unread instead, so that the next state scans them again.
func lexStartState(l *lexer) (lexStateFunc, error) {
	panicIfNil(l, "lexStartState", "lexer")
	r := l.read()
	switch r {
	case eof:
		l.emit(eofToken, "")
		return lexEofState, nil
	// tokens that may have a trailing "="
	case '>':
		return lexGreaterState, nil
	case '<':
		return lexLesserState, nil
	// single-rune tokens (or "!=")
	case '%':
		return lexContainsState, nil
	case '=':
		return lexEqualState, nil
	case '!':
		return lexNotEqualState, nil
	case ')':
		return lexRightParenState, nil
	case '(':
		return lexLeftParenState, nil
	}
	if isSpace(r) {
		return lexWhitespaceState, nil
	}
	// the remaining states scan the whole token themselves, starting with r
	l.unread()
	switch {
	case unicode.IsDigit(r) || r == '.':
		return lexNumberState, nil
	case isDelimiter(r):
		return lexStringState, nil
	default:
		return lexSymbolState, nil
	}
}
// lexStringState scans a quoted string, emits its unquoted contents as a
// stringToken, and returns to lexStartState.
//
// The first rune read must be a delimiter (see isDelimiter). The string ends
// at the next unescaped occurrence of that same delimiter. Inside the string:
//   - a backslash followed by the delimiter or by another backslash yields that
//     second rune alone;
//   - a backslash followed by any other rune keeps both runes as written.
//
// It returns an error wrapping:
//   - ErrInvalidDelimiter if the first rune is not a delimiter,
//   - ErrInvalidTrailingBackslash if the input ends right after a backslash,
//   - ErrMissingEndOfStringTokenDelimiter if eof is reached before the closing
//     delimiter.
func lexStringState(l *lexer) (lexStateFunc, error) {
	const op = "mql.lexStringState"
	panicIfNil(l, "lexStringState", "lexer")
	defer l.current.clear()
	// we'll push the runes we read into this buffer and when appropriate will
	// emit tokens using the buffer's data.
	var tokenBuf bytes.Buffer
	// before we start looping, find out which delimiter opens the string; only
	// the same delimiter can close it.
	r := l.read()
	delimiter := r
	if !isDelimiter(delimiter) {
		return nil, fmt.Errorf("%s: %w %q", op, ErrInvalidDelimiter, delimiter)
	}
	finalDelimiter := false
WriteToBuf:
	// keep reading runes into the buffer until we encounter eof or the final delimiter.
	for {
		r = l.read()
		switch {
		case r == eof:
			break WriteToBuf
		case r == backslash:
			// escape sequence: decide based on the rune after the backslash
			nextR := l.read()
			switch {
			case nextR == eof:
				tokenBuf.WriteRune(r)
				return nil, fmt.Errorf("%s: %w in %q", op, ErrInvalidTrailingBackslash, tokenBuf.String())
			case nextR == backslash:
				tokenBuf.WriteRune(nextR)
			case nextR == delimiter:
				tokenBuf.WriteRune(nextR)
			default:
				// not a recognized escape: keep the backslash and the rune
				tokenBuf.WriteRune(r)
				tokenBuf.WriteRune(nextR)
			}
		case r == delimiter: // end of the quoted string we're scanning
			finalDelimiter = true
			break WriteToBuf
		default: // otherwise, write the rune into the keyword buffer
			tokenBuf.WriteRune(r)
		}
	}
	switch {
	case !finalDelimiter:
		return nil, fmt.Errorf("%s: %w for \"%s", op, ErrMissingEndOfStringTokenDelimiter, tokenBuf.String())
	default:
		l.emit(stringToken, tokenBuf.String())
		return lexStartState, nil
	}
}
// lexSymbolState scans a bare (unquoted) word and then returns to
// lexStartState. The word ends at eof, whitespace, or a special rune (see
// isSpecial); a terminating rune is unread so the next state sees it.
//
// If the word equals "and" or "or" (case-insensitively), it emits andToken or
// orToken. Otherwise it emits a symbolToken carrying the word as written.
func lexSymbolState(l *lexer) (lexStateFunc, error) {
	panicIfNil(l, "lexSymbolState", "lexer")
	defer l.current.clear()
ReadRunes:
	// keep reading runes (read pushes them onto l.current) until we encounter
	// eof or a non-symbol rune.
	for {
		r := l.read()
		switch {
		case r == eof:
			break ReadRunes
		case isSpace(r) || isSpecial(r): // whitespace or a special char
			l.unread()
			break ReadRunes
		}
	}
	symbol := runesToString(l.current)
	switch strings.ToLower(symbol) {
	case "and":
		l.emit(andToken, "and")
	case "or":
		l.emit(orToken, "or")
	default:
		l.emit(symbolToken, symbol)
	}
	return lexStartState, nil
}
// lexNumberState scans a run of digits containing at most one '.', emits it as
// a numberToken, and returns to lexStartState. lexStartState unreads the first
// rune before transitioning here, so that rune is scanned again as part of the
// number.
//
// Scanning stops at eof, or at the first rune that is neither a digit nor '.'.
// That rune is unread so the next state sees it. A second '.' is an error
// wrapping ErrInvalidNumber.
func lexNumberState(l *lexer) (lexStateFunc, error) {
	const op = "mql.lexNumberState"
	panicIfNil(l, "lexNumberState", "lexer")
	defer l.current.clear()
	isFloat := false
	// we'll push the runes we read into this buffer and emit the token using
	// the buffer's data.
	var buf []rune
WriteToBuf:
	// keep reading runes into the buffer until we encounter eof or a non-number rune.
	for {
		r := l.read()
		switch {
		case r == eof:
			break WriteToBuf
		case r == '.':
			buf = append(buf, r)
			if isFloat {
				return nil, fmt.Errorf("%s: %w in %q", op, ErrInvalidNumber, string(buf))
			}
			isFloat = true
		case unicode.IsDigit(r):
			buf = append(buf, r)
		default:
			l.unread()
			break WriteToBuf
		}
	}
	l.emit(numberToken, string(buf))
	return lexStartState, nil
}
// lexContainsState emits a containsToken ("%"), discards the rune that
// lexStartState read, and returns to lexStartState.
func lexContainsState(l *lexer) (lexStateFunc, error) {
	panicIfNil(l, "lexContainsState", "lexer")
	l.emit(containsToken, "%")
	l.current.clear()
	return lexStartState, nil
}
// lexEqualState emits an equalToken ("="), discards the rune that
// lexStartState read, and returns to lexStartState.
func lexEqualState(l *lexer) (lexStateFunc, error) {
	panicIfNil(l, "lexEqualState", "lexer")
	l.emit(equalToken, "=")
	l.current.clear()
	return lexStartState, nil
}
// lexNotEqualState is entered after a '!' and expects the next rune to be '='.
// In that case it emits a notEqualToken ("!=") and returns to lexStartState.
// Otherwise it returns an error wrapping ErrInvalidNotEqual that shows what
// followed the '!'. At eof only "!" is shown, because eof is not a printable
// rune.
func lexNotEqualState(l *lexer) (lexStateFunc, error) {
	const op = "mql.lexNotEqualState"
	panicIfNil(l, "lexNotEqualState", "lexer")
	defer l.current.clear()
	nextRune := l.read()
	switch nextRune {
	case '=':
		l.emit(notEqualToken, "!=")
		return lexStartState, nil
	case eof:
		// string(eof) would render as U+FFFD, which is not in the input
		return nil, fmt.Errorf("%s: %w, got %q", op, ErrInvalidNotEqual, "!")
	default:
		return nil, fmt.Errorf("%s: %w, got %q", op, ErrInvalidNotEqual, "!"+string(nextRune))
	}
}
// lexLeftParenState emits a startLogicalExprToken and returns to
// lexStartState. The token's value is the contents of l.current, which is
// expected to hold the "(" that lexStartState just read.
func lexLeftParenState(l *lexer) (lexStateFunc, error) {
	panicIfNil(l, "lexLeftParenState", "lexer")
	defer l.current.clear()
	paren := runesToString(l.current)
	l.emit(startLogicalExprToken, paren)
	return lexStartState, nil
}
// lexRightParenState emits an endLogicalExprToken and returns to
// lexStartState. The token's value is the contents of l.current, which is
// expected to hold the ")" that lexStartState just read.
func lexRightParenState(l *lexer) (lexStateFunc, error) {
	panicIfNil(l, "lexRightParenState", "lexer")
	defer l.current.clear()
	paren := runesToString(l.current)
	l.emit(endLogicalExprToken, paren)
	return lexStartState, nil
}
// lexWhitespaceState consumes a run of whitespace (see isSpace) and emits a
// single whitespaceToken with an empty value. The first non-space rune is
// unread for the next state. It then returns to lexStartState.
func lexWhitespaceState(l *lexer) (lexStateFunc, error) {
	panicIfNil(l, "lexWhitespaceState", "lexer")
	defer l.current.clear()
	for {
		ch := l.read()
		if ch == eof {
			break
		}
		if !isSpace(ch) {
			l.unread()
			break
		}
	}
	l.emit(whitespaceToken, "")
	return lexStartState, nil
}
// lexGreaterState is entered after a '>'. If the next rune is '=', it emits a
// greaterThanOrEqualToken. Otherwise it unreads that rune and emits a
// greaterThanToken. Either way it returns to lexStartState.
func lexGreaterState(l *lexer) (lexStateFunc, error) {
	panicIfNil(l, "lexGreaterState", "lexer")
	defer l.current.clear()
	if l.read() == '=' {
		l.emit(greaterThanOrEqualToken, ">=")
		return lexStartState, nil
	}
	// not part of this operator; leave it for the next state
	l.unread()
	l.emit(greaterThanToken, ">")
	return lexStartState, nil
}
// lexLesserState is entered after a '<'. If the next rune is '=', it emits a
// lessThanOrEqualToken. Otherwise it unreads that rune and emits a
// lessThanToken. Either way it returns to lexStartState.
func lexLesserState(l *lexer) (lexStateFunc, error) {
	panicIfNil(l, "lexLesserState", "lexer")
	defer l.current.clear()
	if l.read() == '=' {
		l.emit(lessThanOrEqualToken, "<=")
		return lexStartState, nil
	}
	// not part of this operator; leave it for the next state
	l.unread()
	l.emit(lessThanToken, "<")
	return lexStartState, nil
}
// lexEofState emits another eofToken and stays in lexEofState, so every later
// call to nextToken also returns an eofToken.
func lexEofState(l *lexer) (lexStateFunc, error) {
	panicIfNil(l, "lexEofState", "lexer")
	var next lexStateFunc = lexEofState
	l.emit(eofToken, "")
	return next, nil
}
// emit sends a token with type t and value v on the lexer's tokens channel.
// The send blocks if the channel's buffer is already full.
func (l *lexer) emit(t tokenType, v string) {
	tk := token{Type: t, Value: v}
	l.tokens <- tk
}
// isSpace reports whether r is one of the whitespace runes the lexer
// recognizes: space, tab, carriage return or newline.
func isSpace(r rune) bool {
	switch r {
	case ' ', '\t', '\r', '\n':
		return true
	default:
		return false
	}
}
// isSpecial reports whether r is one of the runes that begin an operator or
// parenthesis token: = > ! < ( ) %
func isSpecial(r rune) bool {
	switch r {
	case '=', '>', '!', '<', '(', ')', '%':
		return true
	default:
		return false
	}
}
// read returns the next rune from the source and pushes it onto l.current.
// Any read error, including end of input, is reported as eof, and nothing is
// pushed in that case.
func (l *lexer) read() rune {
	ch, _, err := l.source.ReadRune()
	if err == nil {
		l.current.push(ch)
		return ch
	}
	return eof
}
// unread steps back over the last rune read, so that rune is returned by the
// next lexer.read call. It also pops the last rune from l.current.
//
// The UnreadRune error is ignored. bufio only reports one when the previous
// operation was not a successful ReadRune, e.g. when the last read hit eof.
// NOTE(review): in that case the pop still removes whatever rune was on top of
// l.current.
func (l *lexer) unread() {
	_, _ = l.current.pop()
	_ = l.source.UnreadRune()
}
// isDelimiter reports whether r is one of the string-quoting Delimiters:
// DoubleQuote, SingleQuote or Backtick.
func isDelimiter(r rune) bool {
	d := Delimiter(r)
	return d == DoubleQuote || d == SingleQuote || d == Backtick
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package mql
import (
"fmt"
"reflect"
"strings"
)
// WhereClause contains a SQL where clause condition and its arguments.
type WhereClause struct {
	// Condition is the where clause condition. The built-in conversion uses
	// "?" placeholders, which Parse rewrites to $1, $2, ... when given
	// WithPgPlaceholders.
	Condition string
	// Args for the where clause condition, in the order of their placeholders.
	Args []any
}
// Parse parses query and uses model (a struct or a pointer to a struct) to
// validate the query's columns and values. It returns the equivalent SQL
// where clause.
//
// Supported options: WithColumnMap, WithColumnFieldTag, WithIgnoredFields,
// WithConverter, WithTableColumnMap, WithPgPlaceholders.
//
// It returns an error wrapping ErrInvalidParameter when query is empty or
// model is nil. Otherwise it returns any error from parsing, validation,
// conversion or option processing.
func Parse(query string, model any, opt ...Option) (*WhereClause, error) {
	const op = "mql.Parse"
	switch {
	case query == "":
		return nil, fmt.Errorf("%s: missing query: %w", op, ErrInvalidParameter)
	case isNil(model):
		return nil, fmt.Errorf("%s: missing model: %w", op, ErrInvalidParameter)
	}
	p := newParser(query)
	tree, err := p.parse()
	if err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	fValidators, err := fieldValidators(reflect.ValueOf(model), opt...)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	e, err := exprToWhereClause(tree, fValidators, opt...)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	opts, err := getOpts(opt...)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	if opts.withPgPlaceholder {
		// Rewrite the first len(e.Args) "?" placeholders, left to right, as
		// $1..$n. NOTE(review): the rewrite is purely textual, so a literal "?"
		// in the condition (for example inside a WithTableColumnMap value)
		// would be rewritten as well.
		for i := range e.Args {
			e.Condition = strings.Replace(e.Condition, "?", fmt.Sprintf("$%d", i+1), 1)
		}
	}
	return e, nil
}
// exprToWhereClause recursively converts e into a WhereClause.
//
// A *comparisonExpr is handled by the WithConverter function registered for
// its column, if there is one. That lookup uses the column exactly as written
// in the query. Otherwise the column is resolved to a validator in fValidators
// and converted by defaultValidateConvert:
//   - with WithColumnFieldTag, the column must match a key of fValidators exactly;
//   - otherwise the column is lower-cased and mapped through WithColumnMap.
//     The result is then looked up lower-cased, with underscores removed.
//
// A *logicalExpr becomes "(<left> <op> <right>)", with the left arguments
// followed by the right ones.
// Supported options: WithColumnMap, WithColumnFieldTag, WithConverter,
// WithTableColumnMap
func exprToWhereClause(e expr, fValidators map[string]validator, opt ...Option) (*WhereClause, error) {
	const op = "mql.exprToWhereClause"
	switch {
	case isNil(e):
		return nil, fmt.Errorf("%s: missing expression: %w", op, ErrInvalidParameter)
	case isNil(fValidators):
		return nil, fmt.Errorf("%s: missing validators: %w", op, ErrInvalidParameter)
	}
	switch v := e.(type) {
	case *comparisonExpr:
		opts, err := getOpts(opt...)
		if err != nil {
			return nil, fmt.Errorf("%s: %w", op, err)
		}
		// a caller-supplied converter takes precedence; its result and error
		// are returned unwrapped
		if validateConvertFn, ok := opts.withValidateConvertFns[v.column]; ok && !isNil(validateConvertFn) {
			return validateConvertFn(v.column, v.comparisonOp, v.value)
		}
		var found bool
		var validator validator
		columnName := v.column
		switch {
		case opts.withColumnFieldTag != "":
			validator, found = fValidators[columnName]
		default:
			columnName = strings.ToLower(v.column)
			if n, ok := opts.withColumnMap[columnName]; ok {
				columnName = n
			}
			validator, found = fValidators[strings.ToLower(strings.ReplaceAll(columnName, "_", ""))]
		}
		if !found {
			// Length 0 with capacity only: a non-zero length here would put
			// empty strings in the list of valid columns in the error.
			// NOTE(review): the list is in map iteration order, which is unspecified.
			cols := make([]string, 0, len(fValidators))
			for c := range fValidators {
				cols = append(cols, c)
			}
			return nil, fmt.Errorf("%s: %w %q %s", op, ErrInvalidColumn, columnName, cols)
		}
		w, err := defaultValidateConvert(columnName, v.comparisonOp, v.value, validator, opt...)
		if err != nil {
			return nil, fmt.Errorf("%s: %w", op, err)
		}
		return w, nil
	case *logicalExpr:
		left, err := exprToWhereClause(v.leftExpr, fValidators, opt...)
		if err != nil {
			return nil, fmt.Errorf("%s: invalid left expr: %w", op, err)
		}
		if v.logicalOp == "" {
			return nil, fmt.Errorf("%s: %w that stated with left expr condition: %q args: %q", op, ErrMissingLogicalOp, left.Condition, left.Args)
		}
		right, err := exprToWhereClause(v.rightExpr, fValidators, opt...)
		if err != nil {
			return nil, fmt.Errorf("%s: invalid right expr: %w", op, err)
		}
		return &WhereClause{
			Condition: fmt.Sprintf("(%s %s %s)", left.Condition, v.logicalOp, right.Condition),
			Args:      append(left.Args, right.Args...),
		}, nil
	default:
		return nil, fmt.Errorf("%s: unexpected expr type %T: %w", op, v, ErrInternal)
	}
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package mql
import (
"fmt"
)
// options holds the settings configured by Option functions.
type options struct {
	withSkipWhitespace     bool                           // parser.scan: skip whitespace tokens
	withColumnMap          map[string]string              // lower-cased query column -> model field (WithColumnMap)
	withColumnFieldTag     string                         // struct tag whose value names a field's column (WithColumnFieldTag)
	withValidateConvertFns map[string]ValidateConvertFunc // per-column custom converters (WithConverter)
	withIgnoredFields      []string                       // model field names to skip (WithIgnoredFields)
	withPgPlaceholder      bool                           // rewrite "?" as $1..$n (WithPgPlaceholders)
	withTableColumnMap     map[string]string              // map of model field names to their table.column name
}

// Option configures optional behavior. It returns an error when its
// arguments are invalid or conflict with options already applied.
type Option func(*options) error
// getDefaultOptions returns options whose maps are empty but non-nil, so that
// Option functions and lookups can use them directly. All other settings are
// left at their zero values.
func getDefaultOptions() options {
	var o options
	o.withColumnMap = make(map[string]string)
	o.withValidateConvertFns = make(map[string]ValidateConvertFunc)
	o.withTableColumnMap = make(map[string]string)
	return o
}
// getOpts applies each Option, in order, to a fresh set of default options.
// It stops at the first Option that fails and returns that error together
// with the options applied so far.
func getOpts(opt ...Option) (options, error) {
	opts := getDefaultOptions()
	for i := range opt {
		if err := opt[i](&opts); err != nil {
			return opts, err
		}
	}
	return opts, nil
}
// withSkipWhitespace provides an option requesting that parser.scan skip
// whitespace tokens.
func withSkipWhitespace() Option {
	enable := func(o *options) error {
		o.withSkipWhitespace = true
		return nil
	}
	return enable
}
// WithColumnMap provides an optional map from the columns in a user-provided
// query to fields of the given model. Keys are matched against the lower-cased
// query column. Values are matched against model field names
// case-insensitively, with underscores ignored. A nil map is ignored. It
// cannot be used together with WithColumnFieldTag.
func WithColumnMap(m map[string]string) Option {
	const op = "mql.WithColumnMap"
	return func(o *options) error {
		switch {
		case isNil(m):
			// nothing to apply; keep the current column map
			return nil
		case o.withColumnFieldTag != "":
			return fmt.Errorf("%s: cannot be used with WithColumnFieldTag: %w", op, ErrInvalidParameter)
		}
		o.withColumnMap = m
		return nil
	}
}
// WithColumnFieldTag provides an optional struct tag name used to map query
// columns to model fields. When set, every non-ignored model field must carry
// this tag. The tag value up to the first comma becomes the field's column
// name, and query columns are matched against it exactly (case-sensitive).
// tagName must not be empty, and the option cannot be used together with a
// non-empty WithColumnMap.
func WithColumnFieldTag(tagName string) Option {
	const op = "mql.WithColumnFieldTag"
	return func(o *options) error {
		switch {
		case tagName == "":
			return fmt.Errorf("%s: empty tag name: %w", op, ErrInvalidParameter)
		case len(o.withColumnMap) > 0:
			return fmt.Errorf("%s: cannot be used with WithColumnMap: %w", op, ErrInvalidParameter)
		}
		o.withColumnFieldTag = tagName
		return nil
	}
}
// ValidateConvertFunc validates the value and then converts the columnName,
// comparisonOp and value to a WhereClause. It is registered per column with
// WithConverter. exprToWhereClause calls it with the column exactly as written
// in the query and returns its result and error unwrapped.
type ValidateConvertFunc func(columnName string, comparisonOp ComparisonOp, value *string) (*WhereClause, error)
// WithConverter provides an optional ValidateConvertFunc for a column
// identifier in the query. This lets you supply whatever custom
// validation+conversion you need on a per-column basis; see
// defaultValidateConvert for the built-in behavior.
//
// fieldName is matched against the column exactly as it appears in the query.
// Registering the same fieldName twice is an error, as is supplying only one
// of fieldName and fn. Supplying neither is a no-op.
func WithConverter(fieldName string, fn ValidateConvertFunc) Option {
	const op = "mql.WithSqlConverter"
	return func(o *options) error {
		hasName, hasFn := fieldName != "", !isNil(fn)
		switch {
		case hasName && hasFn:
			if _, dup := o.withValidateConvertFns[fieldName]; dup {
				return fmt.Errorf("%s: duplicated convert: %w", op, ErrInvalidParameter)
			}
			o.withValidateConvertFns[fieldName] = fn
			return nil
		case hasFn:
			return fmt.Errorf("%s: missing field name: %w", op, ErrInvalidParameter)
		case hasName:
			return fmt.Errorf("%s: missing ConvertToSqlFunc: %w", op, ErrInvalidParameter)
		default:
			return nil
		}
	}
}
// WithIgnoredFields provides an optional list of fields to ignore in the model
// (your Go struct) when parsing. Note: Field names are case sensitive. Each
// use replaces any previously configured list.
func WithIgnoredFields(fieldName ...string) Option {
	ignore := func(o *options) error {
		o.withIgnoredFields = fieldName
		return nil
	}
	return ignore
}
// WithPgPlaceholders makes Parse use parameter placeholders that are
// compatible with the postgres pg drivers, which require a placeholder like $1
// instead of ?.
// See:
//   - https://pkg.go.dev/github.com/jackc/pgx/v5
//   - https://pkg.go.dev/github.com/lib/pq
func WithPgPlaceholders() Option {
	enable := func(o *options) error {
		o.withPgPlaceholder = true
		return nil
	}
	return enable
}
// WithTableColumnMap provides an optional map from model columns to the
// table.column expression used in the generated where clause.
//
// For example, if you need to map the language field name to something
// more complex in your SQL statement then you can use this map:
//
//	WithTableColumnMap(map[string]string{"language":"preferences->>'language'"})
//
// In the example above we're mapping the "language" field to a json field in
// the "preferences" column. A user can say `language="blah"` and the
// mql-created SQL where clause will use `preferences->>'language'` in place of
// the column.
//
// Keys are matched against the column name produced by exprToWhereClause.
// That is the lower-cased query column, after any WithColumnMap mapping, or
// the raw query column when WithColumnFieldTag is used. The field names in the
// keys should therefore normally be lower case. A nil map is ignored.
func WithTableColumnMap(m map[string]string) Option {
	return func(o *options) error {
		if isNil(m) {
			return nil
		}
		o.withTableColumnMap = m
		return nil
	}
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package mql
import (
"fmt"
"strings"
"unicode"
)
// parser builds an expr tree from the tokens produced by its lexer.
type parser struct {
	l               *lexer
	raw             string          // the query as given to newParser, used in error messages
	currentToken    token           // most recent token returned by scan
	openLogicalExpr stack[struct{}] // something very simple to make sure every logical expr that's opened is closed.
}
// newParser returns a parser for query s. The lexer sees a cleaned-up copy of
// s: leading and trailing whitespace is trimmed, and whitespace directly
// before a ')' is removed (issue #42). The untouched s is kept for error
// messages.
func newParser(s string) *parser {
	fixedUp := removeSpacesBeforeParen(strings.TrimSpace(s))
	return &parser{
		l:   newLexer(fixedUp),
		raw: s,
	}
}
// parse parses the whole query and returns the root of the resulting expr
// tree (see root). Errors from parsing or from root are wrapped with this
// function's op.
func (p *parser) parse() (expr, error) {
	const op = "mql.(parser).parse"
	lExpr, err := p.parseLogicalExpr()
	var r expr
	if err == nil {
		r, err = root(lExpr, p.raw)
	}
	if err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	return r, nil
}
// parseLogicalExpr parses tokens into a logicalExpr. It calls
// parseComparisonExpr for each comparison and recurses for each '('.
//
// When a comparison fills the rightExpr, the completed left/op/right node is
// wrapped as the leftExpr of a new logicalExpr with an empty logicalOp. The
// result may therefore be a chain of such wrappers, which root unwraps.
//
// Parsing ends when:
//   - an eofToken is reached;
//   - an endLogicalExprToken is seen (it returns early);
//   - a parenthesized sub-expression has just been assigned.
//
// Parentheses are tracked by scan; if any remain open when the loop ends, the
// returned error wraps ErrMissingClosingParen.
func (p *parser) parseLogicalExpr() (*logicalExpr, error) {
	const op = "parseLogicalExpr"
	logicExpr := &logicalExpr{}
	if err := p.scan(withSkipWhitespace()); err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
TkLoop:
	for p.currentToken.Type != eofToken {
		switch p.currentToken.Type {
		case startLogicalExprToken: // there's a opening paren: (
			// so we've found a new logical expr to parse
			e, err := p.parseLogicalExpr()
			if err != nil {
				return nil, fmt.Errorf("%s: %w", op, err)
			}
			switch {
			// start by assigning the left expr
			case logicExpr.leftExpr == nil:
				logicExpr.leftExpr = e
				break TkLoop
			// we should have a logical operator before the right side expr is assigned
			case logicExpr.logicalOp == "":
				return nil, fmt.Errorf("%s: %w before right side expression in: %q", op, ErrMissingLogicalOp, p.raw)
			// finally, assign the right expr
			case logicExpr.rightExpr == nil:
				if e.rightExpr != nil {
					// if e.rightExpr isn't nil, then we've got a complete
					// expr (left + op + right) and we need to assign this to
					// our rightExpr
					logicExpr.rightExpr = e
					break TkLoop
				}
				// otherwise, we need to assign the left side of e
				logicExpr.rightExpr = e.leftExpr
				break TkLoop
			}
		case stringToken, numberToken, symbolToken:
			// a comparison may only start a node or directly follow its
			// logical operator
			if (logicExpr.leftExpr != nil && logicExpr.logicalOp == "") ||
				(logicExpr.leftExpr != nil && logicExpr.rightExpr != nil) {
				return nil, fmt.Errorf("%s: %w starting at %q in: %q", op, ErrUnexpectedExpr, p.currentToken.Value, p.raw)
			}
			cmpExpr, err := p.parseComparisonExpr()
			if err != nil {
				return nil, fmt.Errorf("%s: %w", op, err)
			}
			switch {
			case logicExpr.leftExpr == nil:
				logicExpr.leftExpr = cmpExpr
			case logicExpr.rightExpr == nil:
				logicExpr.rightExpr = cmpExpr
				// the node is now complete: nest it as the left side of a
				// fresh node so parsing can continue with another operator
				tmpExpr := &logicalExpr{
					leftExpr:  logicExpr,
					logicalOp: "",
					rightExpr: nil,
				}
				logicExpr = tmpExpr
			default:
				return nil, fmt.Errorf("%s: %w at %q, but both left and right expressions already exist in: %q", op, ErrUnexpectedExpr, p.currentToken.Value, p.raw)
			}
		case endLogicalExprToken:
			// a ')' ends the current expression, which must have a left side
			if logicExpr.leftExpr == nil {
				return nil, fmt.Errorf("%s: %w %q but we haven't parsed a left side expression in: %q", op, ErrUnexpectedClosingParen, p.currentToken.Value, p.raw)
			}
			return logicExpr, nil
		case andToken, orToken:
			// only one logical operator per node
			if logicExpr.logicalOp != "" {
				return nil, fmt.Errorf("%s: %w %q when we've already parsed one for expr in: %q", op, ErrUnexpectedLogicalOp, p.currentToken.Value, p.raw)
			}
			o, err := newLogicalOp(p.currentToken.Value)
			if err != nil {
				return nil, fmt.Errorf("%s: %w", op, err)
			}
			logicExpr.logicalOp = o
		default:
			return nil, fmt.Errorf("%s: %w %q in: %q", op, ErrUnexpectedToken, p.currentToken.Value, p.raw)
		}
		if err := p.scan(withSkipWhitespace()); err != nil {
			return nil, fmt.Errorf("%s: %w", op, err)
		}
	}
	// scan pushes on every '(' and pops on every ')'; anything left on the
	// stack is an unclosed paren
	if p.openLogicalExpr.len() > 0 {
		return nil, fmt.Errorf("%s: %w in: %q", op, ErrMissingClosingParen, p.raw)
	}
	return logicExpr, nil
}
// parseComparisonExpr parses a single comparison starting at the current
// token. Tokens must come in the order: column, comparison operator, value.
//
// It returns the comparison once it is complete and followed by a whitespace
// token, or when an eofToken is reached. The following are errors:
//   - a '(' anywhere inside the comparison;
//   - any token other than whitespace or ')' after a complete comparison;
//   - a value that is not a stringToken or a numberToken.
func (p *parser) parseComparisonExpr() (expr, error) {
	const op = "mql.(parser).parseComparisonExpr"
	cmpExpr := &comparisonExpr{}
	// our language (and this parser) def requires the tokens to be in the
	// correct order: column, comparisonOp, value. Swapping this order where the
	// value comes first (value, comparisonOp, column) is not supported
	for p.currentToken.Type != eofToken {
		switch {
		case p.currentToken.Type == startLogicalExprToken:
			switch {
			case cmpExpr.isComplete():
				return nil, fmt.Errorf("%s: %w after %s in: %q", op, ErrUnexpectedOpeningParen, cmpExpr, p.raw)
			default:
				return nil, fmt.Errorf("%s: %w in: %q", op, ErrUnexpectedOpeningParen, p.raw)
			}
		// we already have a complete comparisonExpr, so only whitespace or ')'
		// may follow it.
		// NOTE(review): a ')' after a complete comparison matches none of the
		// cases below, so it is consumed by the scan at the bottom of the loop.
		case cmpExpr.isComplete() &&
			(p.currentToken.Type != whitespaceToken && p.currentToken.Type != endLogicalExprToken):
			return nil, fmt.Errorf("%s: %w %s:%q in: %s", op, ErrUnexpectedToken, p.currentToken.Type, p.currentToken.Value, p.raw)
		// we found whitespace: return the comparison if it is complete,
		// otherwise keep scanning
		case p.currentToken.Type == whitespaceToken:
			if cmpExpr.column != "" && cmpExpr.comparisonOp != "" && cmpExpr.value != nil {
				return cmpExpr, nil
			}
		// columns must come first, so handle those conditions
		case cmpExpr.column == "" && p.currentToken.Type != symbolToken:
			// parseLogicalExpr calls parseComparisonExpr for stringTokens and
			// numberTokens as well as symbolTokens. This case is therefore
			// reached when a comparison starts with a quoted string or a number.
			return nil, fmt.Errorf("%s: %w: we expected a %s and got %s == %s in: %q", op, ErrUnexpectedToken, symbolToken, p.currentToken.Type, p.currentToken.Value, p.raw)
		case cmpExpr.column == "": // has to be a symbolToken representing the column
			cmpExpr.column = p.currentToken.Value
		// after columns, comparison operators must come next
		case cmpExpr.comparisonOp == "":
			c, err := newComparisonOp(p.currentToken.Value)
			if err != nil {
				return nil, fmt.Errorf("%s: %w %q in: %q", op, err, p.currentToken.Value, p.raw)
			}
			cmpExpr.comparisonOp = c
		// finally, values must come at the end
		case cmpExpr.value == nil && (p.currentToken.Type != stringToken && p.currentToken.Type != numberToken && p.currentToken.Type != symbolToken):
			return nil, fmt.Errorf("%s: %w %q in: %q", op, ErrUnexpectedToken, p.currentToken.Value, p.raw)
		case cmpExpr.value == nil:
			switch {
			// bare symbols are not accepted as values; they must be quoted
			case p.currentToken.Type == symbolToken:
				return nil, fmt.Errorf("%s: %w %s == %s (expected: %s or %s) in %q", op, ErrInvalidComparisonValueType, p.currentToken.Type, p.currentToken.Value, stringToken, numberToken, p.raw)
			case p.currentToken.Type == stringToken, p.currentToken.Type == numberToken:
				s := p.currentToken.Value
				cmpExpr.value = &s
			default:
				return nil, fmt.Errorf("%s: %w of %s == %s", op, ErrUnexpectedToken, p.currentToken.Type, p.currentToken.Value)
			}
		}
		if err := p.scan(); err != nil {
			return nil, fmt.Errorf("%s: %w", op, err)
		}
	}
	// Reaching eof with a column but no operator is an error here. A
	// comparison missing only its value is returned as-is;
	// defaultValidateConvert rejects a nil value later.
	switch {
	case cmpExpr.column != "" && cmpExpr.comparisonOp == "":
		return nil, fmt.Errorf("%s: %w in: %q", op, ErrMissingComparisonOp, p.raw)
	default:
		return cmpExpr, nil
	}
}
// scan advances p.currentToken to the next token from the lexer. With
// withSkipWhitespace, whitespace tokens are skipped. It records every '(' and
// ')' it lands on in p.openLogicalExpr, so that unbalanced parentheses can be
// detected. Supported options: withSkipWhitespace
func (p *parser) scan(opt ...Option) error {
	const op = "mql.(parser).scan"
	opts, err := getOpts(opt...)
	if err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	for {
		if p.currentToken, err = p.l.nextToken(); err != nil {
			return fmt.Errorf("%s: %w", op, err)
		}
		if !opts.withSkipWhitespace || p.currentToken.Type != whitespaceToken {
			break
		}
	}
	switch p.currentToken.Type {
	case startLogicalExprToken:
		p.openLogicalExpr.push(struct{}{})
	case endLogicalExprToken:
		p.openLogicalExpr.pop()
	}
	return nil
}
// removeSpacesBeforeParen drops every run of whitespace (as defined by
// unicode.IsSpace) that is immediately followed by ')'. All other text,
// including other whitespace, is copied unchanged, except that invalid UTF-8
// bytes come out as U+FFFD.
func removeSpacesBeforeParen(s string) string {
	var out strings.Builder
	out.Grow(len(s))
	wsStart := -1 // byte offset where the pending whitespace run began, or -1
	for i, r := range s {
		switch {
		case unicode.IsSpace(r):
			if wsStart < 0 {
				wsStart = i
			}
		case r == ')':
			// discard the pending whitespace, keep the paren
			wsStart = -1
			out.WriteRune(r)
		default:
			if wsStart >= 0 {
				out.WriteString(s[wsStart:i])
				wsStart = -1
			}
			out.WriteRune(r)
		}
	}
	// trailing whitespace is not followed by ')', so keep it
	if wsStart >= 0 {
		out.WriteString(s[wsStart:])
	}
	return out.String()
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package mql
// stack is a minimal LIFO stack backed by a slice. Its zero value is an empty
// stack that is ready to use.
type stack[T any] struct {
	data []T
}

// push places v on top of the stack.
func (s *stack[T]) push(v T) {
	s.data = append(s.data, v)
}

// pop removes and returns the top element. When the stack is empty it
// returns the zero value of T and false.
func (s *stack[T]) pop() (T, bool) {
	n := len(s.data)
	if n == 0 {
		var zero T
		return zero, false
	}
	top := s.data[n-1]
	s.data = s.data[:n-1]
	return top, true
}

// clear empties the stack, dropping its backing slice.
func (s *stack[T]) clear() {
	s.data = nil
}

// len reports the number of elements on the stack.
func (s *stack[T]) len() int {
	return len(s.data)
}

// runesToString returns the runes in s, bottom to top, as a string.
func runesToString(s stack[rune]) string {
	return string(s.data)
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package mql
// token is a single lexical token produced by the lexer.
type token struct {
	Type  tokenType
	Value string
}

// tokenType identifies the kind of a token.
type tokenType int

// eof is the sentinel rune that lexer.read returns when no rune can be read.
const eof rune = -1

const (
	unknownToken tokenType = iota
	eofToken
	whitespaceToken
	stringToken
	startLogicalExprToken
	endLogicalExprToken
	greaterThanToken
	greaterThanOrEqualToken
	lessThanToken
	lessThanOrEqualToken
	equalToken
	notEqualToken
	containsToken
	numberToken
	symbolToken
	// keywords
	andToken
	orToken
)

// tokenTypeToString maps each tokenType to the short name returned by
// tokenType.String.
var tokenTypeToString = map[tokenType]string{
	unknownToken:            "unknown",
	eofToken:                "eof",
	whitespaceToken:         "ws",
	stringToken:             "str",
	startLogicalExprToken:   "lparen",
	endLogicalExprToken:     "rparen",
	greaterThanToken:        "gt",
	greaterThanOrEqualToken: "gte",
	lessThanToken:           "lt",
	lessThanOrEqualToken:    "lte",
	equalToken:              "eq",
	notEqualToken:           "neq",
	containsToken:           "contains",
	andToken:                "and",
	orToken:                 "or",
	numberToken:             "num",
	symbolToken:             "symbol",
}
// String returns the short name of the tokenType (for example "eq" or
// "symbol"). tokenTypes missing from tokenTypeToString yield "unknown".
func (t tokenType) String() string {
	if name, found := tokenTypeToString[t]; found {
		return name
	}
	return tokenTypeToString[unknownToken]
}
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package mql
import (
"fmt"
"reflect"
"strconv"
"strings"
"golang.org/x/exp/slices"
)
// validator pairs a validateFunc with a type name. fieldValidators sets typ
// to one of "float", "int", "time" or "default", and defaultValidateConvert
// checks for "time" to cast the column with "::date".
type validator struct {
	fn  validateFunc
	typ string
}

// validateFunc is used to validate a column value by converting it as needed,
// validating the value, and returning the converted value
type validateFunc func(columnValue string) (columnVal any, err error)
// fieldValidators returns a validator for each field of model (a struct or a
// pointer to a struct), keyed by the column name used in queries.
//
// By default the key is the lower-cased field name. With WithColumnFieldTag it
// is the tag value up to the first comma, and a field without such a value is
// an error. Fields named in WithIgnoredFields are skipped. The validator is
// chosen from the field's type; one level of pointer indirection is ignored,
// so *int is treated like int.
// Supported options: WithIgnoredFields, WithColumnFieldTag
func fieldValidators(model reflect.Value, opt ...Option) (map[string]validator, error) {
	const op = "mql.fieldValidators"
	if !model.IsValid() {
		return nil, fmt.Errorf("%s: missing model: %w", op, ErrInvalidParameter)
	}
	m := model
	if m.Kind() == reflect.Pointer {
		m = m.Elem()
	}
	if m.Kind() != reflect.Struct {
		return nil, fmt.Errorf("%s: model must be a struct or a pointer to a struct: %w", op, ErrInvalidParameter)
	}
	opts, err := getOpts(opt...)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	structType := m.Type()
	fValidators := make(map[string]validator, structType.NumField())
	for i := 0; i < structType.NumField(); i++ {
		field := structType.Field(i)
		if slices.Contains(opts.withIgnoredFields, field.Name) {
			continue
		}
		fName := strings.ToLower(field.Name)
		if tag := opts.withColumnFieldTag; tag != "" {
			// only the part before the first comma names the column
			fName, _, _ = strings.Cut(field.Tag.Get(tag), ",")
			if fName == "" {
				return nil, fmt.Errorf("%s: field %q has an invalid tag %q: %w", op, field.Name, tag, ErrInvalidParameter)
			}
		}
		var v validator
		switch strings.TrimPrefix(field.Type.String(), "*") {
		case "float32", "float64":
			v = validator{fn: validateFloat, typ: "float"}
		case "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64":
			v = validator{fn: validateInt, typ: "int"}
		case "time.Time":
			v = validator{fn: validateDefault, typ: "time"}
		default:
			v = validator{fn: validateDefault, typ: "default"}
		}
		fValidators[fName] = v
	}
	return fValidators, nil
}
// validateDefault is the no-op validation used by default: it returns s
// unchanged (as a string) and never fails.
func validateDefault(s string) (any, error) {
	var v any = s
	return v, nil
}
// validateInt converts s to an int with strconv.Atoi. On failure it returns 0
// and an error wrapping ErrInvalidParameter.
func validateInt(s string) (any, error) {
	const op = "mql.validateInt"
	if i, err := strconv.Atoi(s); err == nil {
		return i, nil
	}
	return 0, fmt.Errorf("%s: value %q is not an int: %w", op, s, ErrInvalidParameter)
}
// validateFloat converts s to a float64 with strconv.ParseFloat. On failure
// it returns nil and an error wrapping ErrInvalidParameter.
func validateFloat(s string) (any, error) {
	const op = "mql.validateFloat"
	if f, err := strconv.ParseFloat(s, 64); err == nil {
		return f, nil
	}
	return nil, fmt.Errorf("%s: value %q is not float: %w", op, s, ErrInvalidParameter)
}