// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package mql import ( "fmt" "reflect" ) // isNil reports if a is nil func isNil(a any) bool { if a == nil { return true } switch reflect.TypeOf(a).Kind() { case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Slice, reflect.Func: return reflect.ValueOf(a).IsNil() } return false } // panicIfNil will panic if a is nil func panicIfNil(a any, caller, missing string) { if isNil(a) { panic(fmt.Sprintf("%s: missing %s", caller, missing)) } } func pointer[T any](input T) *T { return &input }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package mql import ( "fmt" ) type exprType int const ( unknownExprType exprType = iota comparisonExprType logicalExprType ) type expr interface { Type() exprType String() string } // ComparisonOp defines a set of comparison operators type ComparisonOp string const ( GreaterThanOp ComparisonOp = ">" GreaterThanOrEqualOp ComparisonOp = ">=" LessThanOp ComparisonOp = "<" LessThanOrEqualOp ComparisonOp = "<=" EqualOp ComparisonOp = "=" NotEqualOp ComparisonOp = "!=" ContainsOp ComparisonOp = "%" ) func newComparisonOp(s string) (ComparisonOp, error) { const op = "newComparisonOp" switch ComparisonOp(s) { case GreaterThanOp, GreaterThanOrEqualOp, LessThanOp, LessThanOrEqualOp, EqualOp, NotEqualOp, ContainsOp: return ComparisonOp(s), nil default: return "", fmt.Errorf("%s: %w %q", op, ErrInvalidComparisonOp, s) } } type comparisonExpr struct { column string comparisonOp ComparisonOp value *string } // Type returns the expr type func (e *comparisonExpr) Type() exprType { return comparisonExprType } // String returns a string rep of the expr func (e *comparisonExpr) String() string { switch e.value { case nil: return fmt.Sprintf("(comparisonExpr: %s %s nil)", e.column, e.comparisonOp) default: return fmt.Sprintf("(comparisonExpr: %s %s %s)", e.column, e.comparisonOp, *e.value) } } func (e *comparisonExpr) isComplete() bool { return e.column != "" && e.comparisonOp != "" && e.value != nil } // defaultValidateConvert will validate the comparison expr value, and then convert the // expr to its SQL equivalence. 
func defaultValidateConvert(columnName string, comparisonOp ComparisonOp, columnValue *string, validator validator, opt ...Option) (*WhereClause, error) { const op = "mql.(comparisonExpr).convertToSql" switch { case columnName == "": return nil, fmt.Errorf("%s: %w", op, ErrMissingColumn) case comparisonOp == "": return nil, fmt.Errorf("%s: %w", op, ErrMissingComparisonOp) case isNil(columnValue): return nil, fmt.Errorf("%s: %w", op, ErrMissingComparisonValue) case validator.fn == nil: return nil, fmt.Errorf("%s: missing validator function: %w", op, ErrInvalidParameter) case validator.typ == "": return nil, fmt.Errorf("%s: missing validator type: %w", op, ErrInvalidParameter) } // everything was validated at the start, so we know this is a valid/complete comparisonExpr e := &comparisonExpr{ column: columnName, comparisonOp: comparisonOp, value: columnValue, } v, err := validator.fn(*e.value) if err != nil { return nil, fmt.Errorf("%s: %q in %s: %w", op, *e.value, e.String(), ErrInvalidParameter) } if validator.typ == "time" { columnName = fmt.Sprintf("%s::date", columnName) } switch e.comparisonOp { case ContainsOp: return &WhereClause{ Condition: fmt.Sprintf("%s like ?", columnName), Args: []any{fmt.Sprintf("%%%s%%", v)}, }, nil default: return &WhereClause{ Condition: fmt.Sprintf("%s%s?", columnName, e.comparisonOp), Args: []any{v}, }, nil } } type logicalOp string const ( andOp logicalOp = "and" orOp logicalOp = "or" ) func newLogicalOp(s string) (logicalOp, error) { const op = "newLogicalOp" switch logicalOp(s) { case andOp, orOp: return logicalOp(s), nil default: return "", fmt.Errorf("%s: %w %q", op, ErrInvalidLogicalOp, s) } } type logicalExpr struct { leftExpr expr logicalOp logicalOp rightExpr expr } // Type returns the expr type func (l *logicalExpr) Type() exprType { return logicalExprType } // String returns a string rep of the expr func (l *logicalExpr) String() string { return fmt.Sprintf("(logicalExpr: %s %s %s)", l.leftExpr, l.logicalOp, l.rightExpr) 
} // root will return the root of the expr tree func root(lExpr *logicalExpr, raw string) (expr, error) { const op = "mql.root" switch { // intentionally not checking raw, since can be an empty string case lExpr == nil: return nil, fmt.Errorf("%s: %w (missing expression)", op, ErrInvalidParameter) } logicalOp := lExpr.logicalOp if logicalOp != "" && lExpr.rightExpr == nil { return nil, fmt.Errorf("%s: %w in: %q", op, ErrMissingRightSideExpr, raw) } for lExpr.logicalOp == "" { switch { case lExpr.leftExpr == nil: return nil, fmt.Errorf("%s: %w nil in: %q", op, ErrMissingExpr, raw) case lExpr.leftExpr.Type() == comparisonExprType: return lExpr.leftExpr, nil default: lExpr = lExpr.leftExpr.(*logicalExpr) } } return lExpr, nil }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package mql import ( "bufio" "bytes" "fmt" "strings" "unicode" ) // Delimiter used to quote strings type Delimiter rune const ( DoubleQuote Delimiter = '"' SingleQuote Delimiter = '\'' Backtick Delimiter = '`' backslash = '\\' ) type lexStateFunc func(*lexer) (lexStateFunc, error) type lexer struct { source *bufio.Reader current stack[rune] tokens chan token state lexStateFunc } func newLexer(s string) *lexer { l := &lexer{ source: bufio.NewReader(strings.NewReader(s)), state: lexStartState, tokens: make(chan token, 1), // define a ring buffer for emitted tokens } return l } // nextToken is the external api for the lexer and it simply returns the next // token or an error. If EOF is encountered while scanning, nextToken will keep // returning an eofToken no matter how many times you call nextToken. func (l *lexer) nextToken() (token, error) { for { select { case tk := <-l.tokens: // return a token if one has been emitted return tk, nil default: // otherwise, keep scanning via the next state var err error if l.state, err = l.state(l); err != nil { return token{}, err } } } } // lexStartState is the start state. It doesn't emit tokens, but rather // transitions to other states. Other states typically transition back to // lexStartState after they emit a token. func lexStartState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexStartState", "lexer") r := l.read() switch { // wait, if it's eof we're done case r == eof: l.emit(eofToken, "") return lexEofState, nil // start with finding all tokens that can have a trailing "=" case r == '>': return lexGreaterState, nil case r == '<': return lexLesserState, nil // now, we can just look at the next rune... 
case r == '%': return lexContainsState, nil case r == '=': return lexEqualState, nil case r == '!': return lexNotEqualState, nil case r == ')': return lexRightParenState, nil case r == '(': return lexLeftParenState, nil case isSpace(r): return lexWhitespaceState, nil case unicode.IsDigit(r) || r == '.': l.unread() return lexNumberState, nil case isDelimiter(r): l.unread() return lexStringState, nil default: l.unread() return lexSymbolState, nil } } // lexStringState scans for strings and can emit a stringToken func lexStringState(l *lexer) (lexStateFunc, error) { const op = "mql.lexStringState" panicIfNil(l, "lexStringState", "lexer") defer l.current.clear() // we'll push the runes we read into this buffer and when appropriate will // emit tokens using the buffer's data. var tokenBuf bytes.Buffer // before we start looping, let's found out if we're scanning a quoted string r := l.read() delimiter := r if !isDelimiter(delimiter) { return nil, fmt.Errorf("%s: %w %q", op, ErrInvalidDelimiter, delimiter) } finalDelimiter := false WriteToBuf: // keep reading runes into the buffer until we encounter eof or the final delimiter. 
for { r = l.read() switch { case r == eof: break WriteToBuf case r == backslash: nextR := l.read() switch { case nextR == eof: tokenBuf.WriteRune(r) return nil, fmt.Errorf("%s: %w in %q", op, ErrInvalidTrailingBackslash, tokenBuf.String()) case nextR == backslash: tokenBuf.WriteRune(nextR) case nextR == delimiter: tokenBuf.WriteRune(nextR) default: tokenBuf.WriteRune(r) tokenBuf.WriteRune(nextR) } case r == delimiter: // end of the quoted string we're scanning finalDelimiter = true break WriteToBuf default: // otherwise, write the rune into the keyword buffer tokenBuf.WriteRune(r) } } switch { case !finalDelimiter: return nil, fmt.Errorf("%s: %w for \"%s", op, ErrMissingEndOfStringTokenDelimiter, tokenBuf.String()) default: l.emit(stringToken, tokenBuf.String()) return lexStartState, nil } } // lexSymbolState scans for strings and can emit the following tokens: // orToken, andToken, containsToken func lexSymbolState(l *lexer) (lexStateFunc, error) { const op = "mql.lexSymbolState" panicIfNil(l, "lexSymbolState", "lexer") defer l.current.clear() ReadRunes: // keep reading runes into the buffer until we encounter eof of non-text runes. for { r := l.read() switch { case r == eof: break ReadRunes case (isSpace(r) || isSpecial(r)): // whitespace or a special char l.unread() break ReadRunes default: continue ReadRunes } } switch strings.ToLower(runesToString(l.current)) { case "and": l.emit(andToken, "and") return lexStartState, nil case "or": l.emit(orToken, "or") return lexStartState, nil default: l.emit(symbolToken, runesToString(l.current)) return lexStartState, nil } } func lexNumberState(l *lexer) (lexStateFunc, error) { const op = "mql.lexNumberState" defer l.current.clear() isFloat := false // we'll push the runes we read into this buffer and when appropriate will // emit tokens using the buffer's data. var buf []rune WriteToBuf: // keep reading runes into the buffer until we encounter eof of non-number runes. 
for { r := l.read() switch { case r == eof: break WriteToBuf case r == '.' && isFloat: buf = append(buf, r) return nil, fmt.Errorf("%s: %w in %q", op, ErrInvalidNumber, string(buf)) case r == '.' && !isFloat: isFloat = true buf = append(buf, r) case unicode.IsDigit(r) || (r == '.' && len(buf) == 0): buf = append(buf, r) default: l.unread() break WriteToBuf } } l.emit(numberToken, string(buf)) return lexStartState, nil } // lexContainsState emits an containsToken and returns to the lexStartState func lexContainsState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexContainsState", "lexer") defer l.current.clear() l.emit(containsToken, "%") return lexStartState, nil } // lexEqualState emits an equalToken and returns to the lexStartState func lexEqualState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexEqualState", "lexer") defer l.current.clear() l.emit(equalToken, "=") return lexStartState, nil } // lexNotEqualState scans for a notEqualToken and return either to the lexStartState or // lexErrorState func lexNotEqualState(l *lexer) (lexStateFunc, error) { const op = "mql.lexNotEqualState" panicIfNil(l, "lexNotEqualState", "lexer") defer l.current.clear() nextRune := l.read() switch nextRune { case '=': l.emit(notEqualToken, "!=") return lexStartState, nil default: return nil, fmt.Errorf("%s: %w, got %q", op, ErrInvalidNotEqual, fmt.Sprintf("%s%s", "!", string(nextRune))) } } // lexLeftParenState emits a startLogicalExprToken and returns to the // lexStartState func lexLeftParenState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexLeftParenState", "lexer") defer l.current.clear() l.emit(startLogicalExprToken, runesToString(l.current)) return lexStartState, nil } // lexRightParenState emits an endLogicalExprToken and returns to the // lexStartState func lexRightParenState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexRightParenState", "lexer") defer l.current.clear() l.emit(endLogicalExprToken, runesToString(l.current)) return lexStartState, nil } // 
lexWhitespaceState emits a whitespaceToken and returns to the lexStartState func lexWhitespaceState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexWhitespaceState", "lexer") defer l.current.clear() ReadWhitespace: for { ch := l.read() switch { case ch == eof: break ReadWhitespace case !isSpace(ch): l.unread() break ReadWhitespace } } l.emit(whitespaceToken, "") return lexStartState, nil } // lexGreaterState will emit either a greaterThanToken or a // greaterThanOrEqualToken and return to the lexStartState func lexGreaterState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexGreaterState", "lexer") defer l.current.clear() next := l.read() switch next { case '=': l.emit(greaterThanOrEqualToken, ">=") return lexStartState, nil default: l.unread() l.emit(greaterThanToken, ">") return lexStartState, nil } } // lexLesserState will emit either a lessThanToken or a lessThanOrEqualToken and // return to the lexStartState func lexLesserState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexLesserState", "lexer") defer l.current.clear() next := l.read() switch next { case '=': l.emit(lessThanOrEqualToken, "<=") return lexStartState, nil default: l.unread() l.emit(lessThanToken, "<") return lexStartState, nil } } // lexEofState will emit an eofToken and returns right back to the lexEofState func lexEofState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexEofState", "lexer") l.emit(eofToken, "") return lexEofState, nil } // emit send a token to the lexer's token channel func (l *lexer) emit(t tokenType, v string) { l.tokens <- token{ Type: t, Value: v, } } // isSpace reports if r is a space func isSpace(r rune) bool { return r == ' ' || r == '\t' || r == '\r' || r == '\n' } // isSpecial reports r is special rune func isSpecial(r rune) bool { return r == '=' || r == '>' || r == '!' 
|| r == '<' || r == '(' || r == ')' || r == '%' } // read the next rune func (l *lexer) read() rune { ch, _, err := l.source.ReadRune() if err != nil { return eof } l.current.push(ch) return ch } // unread the last rune read which means that rune will be returned the next // time lexer.read() is called. unread also removes the last rune from the // lexer's stack of current runes func (l *lexer) unread() { _ = l.source.UnreadRune() // error ignore which only occurs when nothing has been previously read _, _ = l.current.pop() } func isDelimiter(r rune) bool { switch Delimiter(r) { case DoubleQuote, SingleQuote, Backtick: return true default: return false } }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package mql import ( "fmt" "reflect" "strings" ) // WhereClause contains a SQL where clause condition and its arguments. type WhereClause struct { // Condition is the where clause condition Condition string // Args for the where clause condition Args []any } // Parse will parse the query and use the provided database model to create a // where clause. Supported options: WithColumnMap, WithIgnoreFields, // WithConverter, WithPgPlaceholder func Parse(query string, model any, opt ...Option) (*WhereClause, error) { const op = "mql.Parse" switch { case query == "": return nil, fmt.Errorf("%s: missing query: %w", op, ErrInvalidParameter) case isNil(model): return nil, fmt.Errorf("%s: missing model: %w", op, ErrInvalidParameter) } p := newParser(query) expr, err := p.parse() if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } fValidators, err := fieldValidators(reflect.ValueOf(model), opt...) if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } e, err := exprToWhereClause(expr, fValidators, opt...) if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } opts, err := getOpts(opt...) if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } if opts.withPgPlaceholder { for i := 0; i < len(e.Args); i++ { placeholder := fmt.Sprintf("$%d", i+1) e.Condition = strings.Replace(e.Condition, "?", placeholder, 1) } } return e, nil } // exprToWhereClause generates the where clause condition along with its // required arguments. Supported options: WithColumnMap, WithConverter func exprToWhereClause(e expr, fValidators map[string]validator, opt ...Option) (*WhereClause, error) { const op = "mql.exprToWhereClause" switch { case isNil(e): return nil, fmt.Errorf("%s: missing expression: %w", op, ErrInvalidParameter) case isNil(fValidators): return nil, fmt.Errorf("%s: missing validators: %w", op, ErrInvalidParameter) } switch v := e.(type) { case *comparisonExpr: opts, err := getOpts(opt...) 
if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } switch { case opts.withValidateConvertColumn == v.column && !isNil(opts.withValidateConvertFn): return opts.withValidateConvertFn(v.column, v.comparisonOp, v.value) default: columnName := strings.ToLower(v.column) if n, ok := opts.withColumnMap[columnName]; ok { columnName = n } validator, ok := fValidators[strings.ToLower(strings.ReplaceAll(columnName, "_", ""))] if !ok { cols := make([]string, len(fValidators)) for c := range fValidators { cols = append(cols, c) } return nil, fmt.Errorf("%s: %w %q %s", op, ErrInvalidColumn, columnName, cols) } w, err := defaultValidateConvert(columnName, v.comparisonOp, v.value, validator, opt...) if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } return w, nil } case *logicalExpr: left, err := exprToWhereClause(v.leftExpr, fValidators, opt...) if err != nil { return nil, fmt.Errorf("%s: invalid left expr: %w", op, err) } if v.logicalOp == "" { return nil, fmt.Errorf("%s: %w that stated with left expr condition: %q args: %q", op, ErrMissingLogicalOp, left.Condition, left.Args) } right, err := exprToWhereClause(v.rightExpr, fValidators, opt...) if err != nil { return nil, fmt.Errorf("%s: invalid right expr: %w", op, err) } return &WhereClause{ Condition: fmt.Sprintf("(%s %s %s)", left.Condition, v.logicalOp, right.Condition), Args: append(left.Args, right.Args...), }, nil default: return nil, fmt.Errorf("%s: unexpected expr type %T: %w", op, v, ErrInternal) } }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package mql import ( "fmt" ) type options struct { withSkipWhitespace bool withColumnMap map[string]string withValidateConvertFn ValidateConvertFunc withValidateConvertColumn string withIgnoredFields []string withPgPlaceholder bool } // Option - how options are passed as args type Option func(*options) error func getDefaultOptions() options { return options{ withColumnMap: make(map[string]string), } } func getOpts(opt ...Option) (options, error) { opts := getDefaultOptions() for _, o := range opt { if err := o(&opts); err != nil { return opts, err } } return opts, nil } // withSkipWhitespace provides an option to request that whitespace be skipped func withSkipWhitespace() Option { return func(o *options) error { o.withSkipWhitespace = true return nil } } // WithColumnMap provides an optional map of columns from a column in the user // provided query to a column in the database model func WithColumnMap(m map[string]string) Option { return func(o *options) error { if !isNil(m) { o.withColumnMap = m } return nil } } // ValidateConvertFunc validates the value and then converts the columnName, // comparisonOp and value to a WhereClause type ValidateConvertFunc func(columnName string, comparisonOp ComparisonOp, value *string) (*WhereClause, error) // WithConverter provides an optional ConvertFunc for a column identifier in the // query. This allows you to provide whatever custom validation+conversion you // need on a per column basis. See: DefaultValidateConvert(...) for inspiration. 
func WithConverter(fieldName string, fn ValidateConvertFunc) Option { const op = "mql.WithSqlConverter" return func(o *options) error { switch { case fieldName != "" && !isNil(fn): o.withValidateConvertFn = fn o.withValidateConvertColumn = fieldName case fieldName == "" && !isNil(fn): return fmt.Errorf("%s: missing field name: %w", op, ErrInvalidParameter) case fieldName != "" && isNil(fn): return fmt.Errorf("%s: missing ConvertToSqlFunc: %w", op, ErrInvalidParameter) } return nil } } // WithIgnoredFields provides an optional list of fields to ignore in the model // (your Go struct) when parsing. Note: Field names are case sensitive. func WithIgnoredFields(fieldName ...string) Option { return func(o *options) error { o.withIgnoredFields = fieldName return nil } } // WithPgPlaceholders will use parameters placeholders that are compatible with // the postgres pg driver which requires a placeholder like $1 instead of ?. // See: // - https://pkg.go.dev/github.com/jackc/pgx/v5 // - https://pkg.go.dev/github.com/lib/pq func WithPgPlaceholders() Option { return func(o *options) error { o.withPgPlaceholder = true return nil } }
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package mql

import (
	"fmt"
)

// parser builds an expr tree from the tokens produced by its lexer.
type parser struct {
	l               *lexer
	raw             string // the original query, used in error messages
	currentToken    token
	openLogicalExpr stack[struct{}] // something very simple to make sure every logical expr that's opened is closed.
}

// newParser returns a parser for the query s.
func newParser(s string) *parser {
	return &parser{
		l:   newLexer(s),
		raw: s,
	}
}

// parse parses the whole query and returns the root of the resulting expr
// tree (see root for how wrapper logicalExprs are stripped).
func (p *parser) parse() (expr, error) {
	const op = "mql.(parser).parse"
	lExpr, err := p.parseLogicalExpr()
	if err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	r, err := root(lExpr, p.raw)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	return r, nil
}

// parseLogicalExpr will parse a logicalExpr until an eofToken is reached, which
// may require it to parse a comparisonExpr and/or recursively parse
// logicalExprs
//
// Each time both sides of logicExpr are filled, it is wrapped as the left side
// of a new, empty logicalExpr. That wrapper then collects the next operator
// and right side, so chains are built left to right. A parenthesised group is
// parsed by a recursive call. Once the group is assigned as the left or right
// side, the loop stops (break TkLoop).
func (p *parser) parseLogicalExpr() (*logicalExpr, error) {
	const op = "parseLogicalExpr"
	logicExpr := &logicalExpr{}

	if err := p.scan(withSkipWhitespace()); err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
TkLoop:
	for p.currentToken.Type != eofToken {
		switch p.currentToken.Type {
		case startLogicalExprToken: // there's a opening paren: (
			// so we've found a new logical expr to parse
			e, err := p.parseLogicalExpr()
			if err != nil {
				return nil, fmt.Errorf("%s: %w", op, err)
			}
			switch {
			// start by assigning the left expr
			case logicExpr.leftExpr == nil:
				logicExpr.leftExpr = e
				break TkLoop
			// we should have a logical operator before the right side expr is assigned
			case logicExpr.logicalOp == "":
				return nil, fmt.Errorf("%s: %w before right side expression in: %q", op, ErrMissingLogicalOp, p.raw)
			// finally, assign the right expr
			case logicExpr.rightExpr == nil:
				// e's left side holds the group's expression (the recursive
				// call returns a wrapper once both sides are filled)
				logicExpr.rightExpr = e.leftExpr
				break TkLoop
			}
		case stringToken, numberToken, symbolToken:
			if (logicExpr.leftExpr != nil && logicExpr.logicalOp == "") ||
				(logicExpr.leftExpr != nil && logicExpr.rightExpr != nil) {
				return nil, fmt.Errorf("%s: %w starting at %q in: %q", op, ErrUnexpectedExpr, p.currentToken.Value, p.raw)
			}
			cmpExpr, err := p.parseComparisonExpr()
			if err != nil {
				return nil, fmt.Errorf("%s: %w", op, err)
			}
			switch {
			case logicExpr.leftExpr == nil:
				logicExpr.leftExpr = cmpExpr
			case logicExpr.rightExpr == nil:
				logicExpr.rightExpr = cmpExpr
				// both sides are set: wrap the finished expr as the left side
				// of a new expr that collects the next operator/right side
				tmpExpr := &logicalExpr{
					leftExpr:  logicExpr,
					logicalOp: "",
					rightExpr: nil,
				}
				logicExpr = tmpExpr
			default:
				return nil, fmt.Errorf("%s: %w at %q, but both left and right expressions already exist in: %q", op, ErrUnexpectedExpr, p.currentToken.Value, p.raw)
			}
		case endLogicalExprToken:
			if logicExpr.leftExpr == nil {
				return nil, fmt.Errorf("%s: %w %q but we haven't parsed a left side expression in: %q", op, ErrUnexpectedClosingParen, p.currentToken.Value, p.raw)
			}
			return logicExpr, nil
		case andToken, orToken:
			if logicExpr.logicalOp != "" {
				return nil, fmt.Errorf("%s: %w %q when we've already parsed one for expr in: %q", op, ErrUnexpectedLogicalOp, p.currentToken.Value, p.raw)
			}
			o, err := newLogicalOp(p.currentToken.Value)
			if err != nil {
				return nil, fmt.Errorf("%s: %w", op, err)
			}
			logicExpr.logicalOp = o
		default:
			return nil, fmt.Errorf("%s: %w %q in: %q", op, ErrUnexpectedToken, p.currentToken.Value, p.raw)
		}
		if err := p.scan(withSkipWhitespace()); err != nil {
			return nil, fmt.Errorf("%s: %w", op, err)
		}
	}
	// scan pushes on "(" and pops on ")", so anything left means an unclosed
	// paren. NOTE(review): pop on an empty stack is a no-op, so a surplus ")"
	// is not detected by this check.
	if p.openLogicalExpr.len() > 0 {
		return nil, fmt.Errorf("%s: %w in: %q", op, ErrMissingClosingParen, p.raw)
	}
	return logicExpr, nil
}

// parseComparisonExpr will parse a comparisonExpr until an eofToken is reached,
// which may require it to parse logicalExpr
//
// It starts at the current token (it does not scan first) and uses scan
// without skipping whitespace. A complete expr is returned at the first
// whitespace token or at eof.
func (p *parser) parseComparisonExpr() (expr, error) {
	const op = "mql.(parser).parseComparisonExpr"
	cmpExpr := &comparisonExpr{}

	// our language (and this parser) def requires the tokens to be in the
	// correct order: column, comparisonOp, value. Swapping this order where the
	// value comes first (value, comparisonOp, column) is not supported
	for p.currentToken.Type != eofToken {
		switch {
		case p.currentToken.Type == startLogicalExprToken:
			switch {
			case cmpExpr.isComplete():
				return nil, fmt.Errorf("%s: %w after %s in: %q", op, ErrUnexpectedOpeningParen, cmpExpr, p.raw)
			default:
				return nil, fmt.Errorf("%s: %w in: %q", op, ErrUnexpectedOpeningParen, p.raw)
			}

		// we already have a complete comparisonExpr
		case cmpExpr.isComplete() &&
			(p.currentToken.Type != whitespaceToken && p.currentToken.Type != endLogicalExprToken):
			return nil, fmt.Errorf("%s: %w %s:%q in: %s", op, ErrUnexpectedToken, p.currentToken.Type, p.currentToken.Value, p.raw)

		// we found whitespace, so check if there's a completed comparison expr to return
		case p.currentToken.Type == whitespaceToken:
			if cmpExpr.column != "" && cmpExpr.comparisonOp != "" && cmpExpr.value != nil {
				return cmpExpr, nil
			}

		// columns must come first, so handle those conditions
		case cmpExpr.column == "" && p.currentToken.Type != symbolToken:
			// NOTE(review): the original comment called this unreachable, but
			// parseLogicalExpr also calls parseComparisonExpr for stringToken
			// and numberToken, so a query that starts with a quoted string or
			// a number ends up here.
			return nil, fmt.Errorf("%s: %w: we expected a %s and got %s == %s in: %q", op, ErrUnexpectedToken, symbolToken, p.currentToken.Type, p.currentToken.Value, p.raw)
		case cmpExpr.column == "":
			// has to be a symbolToken representing the column (the case above
			// rejects every other token type)
			cmpExpr.column = p.currentToken.Value

		// after columns, comparison operators must come next
		case cmpExpr.comparisonOp == "":
			c, err := newComparisonOp(p.currentToken.Value)
			if err != nil {
				return nil, fmt.Errorf("%s: %w %q in: %q", op, err, p.currentToken.Value, p.raw)
			}
			cmpExpr.comparisonOp = c

		// finally, values must come at the end
		case cmpExpr.value == nil && (p.currentToken.Type != stringToken && p.currentToken.Type != numberToken && p.currentToken.Type != symbolToken):
			return nil, fmt.Errorf("%s: %w %q in: %q", op, ErrUnexpectedToken, p.currentToken.Value, p.raw)
		case cmpExpr.value == nil:
			switch {
			case p.currentToken.Type == symbolToken:
				return nil, fmt.Errorf("%s: %w %s == %s (expected: %s or %s) in %q", op, ErrInvalidComparisonValueType, p.currentToken.Type, p.currentToken.Value, stringToken, numberToken, p.raw)
			case p.currentToken.Type == stringToken, p.currentToken.Type == numberToken:
				s := p.currentToken.Value
				cmpExpr.value = &s
			default:
				return nil, fmt.Errorf("%s: %w of %s == %s", op, ErrUnexpectedToken, p.currentToken.Type, p.currentToken.Value)
			}
		}
		// NOTE: with a complete expr and an endLogicalExprToken, no case above
		// matches, so the ")" is consumed by this scan (scan has already popped
		// openLogicalExpr for it).
		if err := p.scan(); err != nil {
			return nil, fmt.Errorf("%s: %w", op, err)
		}
	}

	switch {
	case cmpExpr.column != "" && cmpExpr.comparisonOp == "":
		return nil, fmt.Errorf("%s: %w in: %q", op, ErrMissingComparisonOp, p.raw)
	default:
		return cmpExpr, nil
	}
}

// scan will get the next token from the lexer. Supported options:
// withSkipWhitespace
//
// It also tracks paren balance: it pushes openLogicalExpr on a "(" token and
// pops it on a ")" token.
func (p *parser) scan(opt ...Option) error {
	const op = "mql.(parser).scan"
	opts, err := getOpts(opt...)
	if err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}

	if p.currentToken, err = p.l.nextToken(); err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}

	if opts.withSkipWhitespace {
		for p.currentToken.Type == whitespaceToken {
			if p.currentToken, err = p.l.nextToken(); err != nil {
				return fmt.Errorf("%s: %w", op, err)
			}
		}
	}

	switch p.currentToken.Type {
	case startLogicalExprToken:
		p.openLogicalExpr.push(struct{}{})
	case endLogicalExprToken:
		p.openLogicalExpr.pop()
	}

	return nil
}
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package mql type stack[T any] struct { data []T } func (s *stack[T]) push(v T) { s.data = append(s.data, v) } func (s *stack[T]) pop() (T, bool) { var x T if len(s.data) > 0 { x, s.data = s.data[len(s.data)-1], s.data[:len(s.data)-1] return x, true } return x, false } func (s *stack[T]) clear() { s.data = nil } func (s *stack[T]) len() int { return len(s.data) } func runesToString(s stack[rune]) string { return string(s.data) }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package mql type token struct { Type tokenType Value string } type tokenType int const eof rune = -1 const ( unknownToken tokenType = iota eofToken whitespaceToken stringToken startLogicalExprToken endLogicalExprToken greaterThanToken greaterThanOrEqualToken lessThanToken lessThanOrEqualToken equalToken notEqualToken containsToken numberToken symbolToken // keywords andToken orToken ) var tokenTypeToString = map[tokenType]string{ unknownToken: "unknown", eofToken: "eof", whitespaceToken: "ws", stringToken: "str", startLogicalExprToken: "lparen", endLogicalExprToken: "rparen", greaterThanToken: "gt", greaterThanOrEqualToken: "gte", lessThanToken: "lt", lessThanOrEqualToken: "lte", equalToken: "eq", notEqualToken: "neq", containsToken: "contains", andToken: "and", orToken: "or", numberToken: "num", symbolToken: "symbol", } // String returns a string of the tokenType and will return "Unknown" for // invalid tokenTypes func (t tokenType) String() string { s, ok := tokenTypeToString[t] switch ok { case true: return s default: return tokenTypeToString[unknownToken] } }
// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package mql import ( "fmt" "reflect" "strconv" "strings" "golang.org/x/exp/slices" ) type validator struct { fn validateFunc typ string } // validateFunc is used to validate a column value by converting it as needed, // validating the value, and returning the converted value type validateFunc func(columnValue string) (columnVal any, err error) // fieldValidators takes a model and returns a map of field names to validate // functions. Supported options: WithIgnoreFields func fieldValidators(model reflect.Value, opt ...Option) (map[string]validator, error) { const op = "mql.fieldValidators" switch { case !model.IsValid(): return nil, fmt.Errorf("%s: missing model: %w", op, ErrInvalidParameter) case (model.Kind() != reflect.Struct && model.Kind() != reflect.Pointer), model.Kind() == reflect.Pointer && model.Elem().Kind() != reflect.Struct: return nil, fmt.Errorf("%s: model must be a struct or a pointer to a struct: %w", op, ErrInvalidParameter) } var m reflect.Value = model if m.Kind() != reflect.Struct { m = model.Elem() } opts, err := getOpts(opt...) if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } fValidators := make(map[string]validator) for i := 0; i < m.NumField(); i++ { if slices.Contains(opts.withIgnoredFields, m.Type().Field(i).Name) { continue } fName := strings.ToLower(m.Type().Field(i).Name) // get a string val of the field type, then strip any leading '*' so we // can simplify the switch below when dealing with types like *int and int. 
fType := strings.TrimPrefix(m.Type().Field(i).Type.String(), "*") switch fType { case "float32", "float64": fValidators[fName] = validator{fn: validateFloat, typ: "float"} case "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64": fValidators[fName] = validator{fn: validateInt, typ: "int"} case "time.Time": fValidators[fName] = validator{fn: validateDefault, typ: "time"} default: fValidators[fName] = validator{fn: validateDefault, typ: "default"} } } return fValidators, nil } // by default, we'll use a no op validation func validateDefault(s string) (any, error) { return s, nil } func validateInt(s string) (any, error) { const op = "mql.validateInt" i, err := strconv.Atoi(s) if err != nil { return 0, fmt.Errorf("%s: value %q is not an int: %w", op, s, ErrInvalidParameter) } return i, nil } func validateFloat(s string) (any, error) { const op = "mql.validateFloat" f, err := strconv.ParseFloat(s, 64) if err != nil { return nil, fmt.Errorf("%s: value %q is not float: %w", op, s, ErrInvalidParameter) } return f, nil }