diff --git a/common.go b/common.go
new file mode 100644
index 0000000..cbbdd9f
--- /dev/null
+++ b/common.go
@@ -0,0 +1,27 @@
+// Copyright (c) HashiCorp, Inc.
+
+package mql
+
+import (
+	"fmt"
+	"reflect"
+)
+
+// isNil reports if a is nil
+func isNil(a any) bool {
+	if a == nil {
+		return true
+	}
+	switch reflect.TypeOf(a).Kind() {
+	case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Slice, reflect.Func:
+		return reflect.ValueOf(a).IsNil()
+	}
+	return false
+}
+
+// panicIfNil will panic if a is nil
+func panicIfNil(a any, caller, missing string) {
+	if isNil(a) {
+		panic(fmt.Sprintf("%s: missing %s", caller, missing))
+	}
+}
diff --git a/common_test.go b/common_test.go
new file mode 100644
index 0000000..43428dc
--- /dev/null
+++ b/common_test.go
@@ -0,0 +1,63 @@
+package mql
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func Test_panicIfNil(t *testing.T) {
+
+	assert.Panics(t, func() {
+		panicIfNil(nil, "test", "missing bit")
+	})
+}
+
+func Test_isNil(t *testing.T) {
+	t.Parallel()
+
+	var testErrNilPtr *testError
+	var testMapNilPtr map[string]struct{}
+	var testArrayNilPtr *[1]string
+	var testChanNilPtr *chan string
+	var testSliceNilPtr *[]string
+	var testFuncNil func()
+
+	var testChanString chan string
+
+	tc := []struct {
+		i    any
+		want bool
+	}{
+		{i: &testError{}, want: false},
+		{i: testError{}, want: false},
+		{i: &map[string]struct{}{}, want: false},
+		{i: map[string]struct{}{}, want: false},
+		{i: [1]string{}, want: false},
+		{i: &[1]string{}, want: false},
+		{i: &testChanString, want: false},
+		{i: "string", want: false},
+		{i: []string{}, want: false},
+		{i: func() {}, want: false},
+		{i: nil, want: true},
+		{i: testErrNilPtr, want: true},
+		{i: testMapNilPtr, want: true},
+		{i: testArrayNilPtr, want: true},
+		{i: testChanNilPtr, want: true},
+		{i: testChanString, want: true},
+		{i: testSliceNilPtr, want: true},
+		{i: testFuncNil, want: true},
+	}
+
+	for i, tc := range tc {
+		t.Run(fmt.Sprintf("test #%d", i+1), func(t *testing.T) {
+			assert := assert.New(t)
+			assert.Equal(tc.want, isNil(tc.i))
+		})
+	}
+}
+
+type testError struct{}
+
+func (*testError) Error() string { return "error" }
diff --git a/error.go b/error.go
new file mode 100644
index 0000000..26c65f0
--- /dev/null
+++ b/error.go
@@ -0,0 +1,7 @@
+package mql
+
+import "errors"
+
+var (
+	ErrInvalidNotEqual = errors.New(`invalid "!=" token`)
+)
diff --git a/go.mod b/go.mod
index 849570d..b1331df 100644
--- a/go.mod
+++ b/go.mod
@@ -1,3 +1,11 @@
 module mql
 
 go 1.20
+
+require github.com/stretchr/testify v1.8.4
+
+require (
+	github.com/davecgh/go-spew v1.1.1 // indirect
+	github.com/pmezard/go-difflib v1.0.0 // indirect
+	gopkg.in/yaml.v3 v3.0.1 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..fa4b6e6
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,10 @@
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
+github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
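Worth noting about isNil above: a plain `a == nil` check misses Go's typed-nil case, where an interface holds a non-nil type descriptor wrapping a nil value, so the reflection fallback is required. A minimal sketch of the trap it guards against (hypothetical demo code, assuming it sits in package mql next to isNil with `fmt` imported):

```go
// typedNilDemo is a hypothetical illustration, not part of this change.
func typedNilDemo() {
	var p *testError        // nil pointer to a concrete type
	var err error = p       // the interface now holds (type=*testError, value=nil)
	fmt.Println(err == nil) // false: the interface itself is non-nil
	fmt.Println(isNil(err)) // true: reflection sees the nil pointer inside
}
```

The Kind switch is deliberately limited to nilable kinds; calling reflect.Value.IsNil on anything else (a struct, a string, an int) would itself panic.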
diff --git a/lex.go b/lex.go
new file mode 100644
index 0000000..c9d3c4f
--- /dev/null
+++ b/lex.go
@@ -0,0 +1,280 @@
+// Copyright (c) HashiCorp, Inc.
+
+package mql
+
+import (
+	"bufio"
+	"bytes"
+	"fmt"
+	"strings"
+)
+
+type lexStateFunc func(*lexer) (lexStateFunc, error)
+
+type lexer struct {
+	source  *bufio.Reader
+	current stack[rune]
+	tokens  chan token
+	state   lexStateFunc
+}
+
+func newLexer(s string) *lexer {
+	l := &lexer{
+		source: bufio.NewReader(strings.NewReader(s)),
+		state:  lexStartState,
+		tokens: make(chan token, 1), // buffered channel (capacity 1) for emitted tokens
+	}
+	return l
+}
+
+// nextToken is the external API for the lexer and it simply returns the next
+// token or an error. If EOF is encountered while scanning, nextToken will keep
+// returning an eofToken no matter how many times you call nextToken.
+func (l *lexer) nextToken() (token, error) {
+	for {
+		select {
+		case tk := <-l.tokens: // return a token if one has been emitted
+			return tk, nil
+		default: // otherwise, keep scanning via the next state
+			var err error
+			if l.state, err = l.state(l); err != nil {
+				return token{}, err
+			}
+
+		}
+	}
+}
+
+// lexStartState is the start state. It doesn't emit tokens, but rather
+// transitions to other states. Other states typically transition back to
+// lexStartState after they emit a token.
+func lexStartState(l *lexer) (lexStateFunc, error) {
+	panicIfNil(l, "startState", "lexer")
+	r := l.read()
+	switch {
+	// if it's eof, we're done
+	case r == eof:
+		l.emit(eofToken, "")
+		return lexEofState, nil
+
+	// start with finding all tokens that can have a trailing "="
+	case r == '>':
+		return lexGreaterState, nil
+	case r == '<':
+		return lexLesserState, nil
+
+	// now, we can just look at the next rune...
+	case r == '%':
+		return lexContainsState, nil
+	case r == '=':
+		return lexEqualState, nil
+	case r == '!':
+		return lexNotEqualState, nil
+	case r == ')':
+		return lexRightParenState, nil
+	case r == '(':
+		return lexLeftParenState, nil
+	case isSpace(r):
+		return lexWhitespaceState, nil
+	default:
+		l.unread()
+		return lexStringState, nil
+	}
+}
+
+// lexStringState scans for strings and can emit the following tokens:
+// orToken, andToken, containsToken, stringToken
+func lexStringState(l *lexer) (lexStateFunc, error) {
+	panicIfNil(l, "lexStringState", "lexer")
+	defer l.current.clear()
+
+	// we'll push the runes we read into this buffer and when appropriate will
+	// emit tokens using the buffer's data.
+	var buf bytes.Buffer
+
+	// before we start looping, let's find out if we're scanning a quoted string
+	r := l.read()
+	var quotedString bool
+	switch r {
+	case '"':
+		quotedString = true
+	default:
+		l.unread()
+	}
+
+WRITE_TO_BUF:
+	// keep reading runes into the buffer until we encounter eof or non-text runes.
+	for {
+		r = l.read()
+		switch {
+		case r == eof:
+			break WRITE_TO_BUF
+		case r == '"' && quotedString: // end of the quoted string we're scanning
+			break WRITE_TO_BUF
+		case (isSpace(r) || isSpecial(r)) && !quotedString: // whitespace or a special char, and we're not scanning a quoted string
+			l.unread()
+			break WRITE_TO_BUF
+		default: // otherwise, write the rune into the buffer
+			buf.WriteRune(r)
+		}
+	}
+
+	// before emitting a token, do we have a special string?
+	switch strings.ToLower(buf.String()) {
+	case "and":
+		l.emit(andToken, "and")
+		return lexStartState, nil
+	case "or":
+		l.emit(orToken, "or")
+		return lexStartState, nil
+	}
+
+	l.emit(stringToken, buf.String())
+	return lexStartState, nil
+}
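newLexer, nextToken, and the state functions form a classic lexer state machine: each state consumes runes, possibly emits one token into the buffered channel, and hands back the next state; nextToken just loops until something lands in the channel. A rough sketch of a driver (a hypothetical helper; it must live in package mql since the lexer is unexported, and needs `fmt` imported):

```go
// dumpTokens is a hypothetical driver that prints the token stream for s.
func dumpTokens(s string) error {
	lex := newLexer(s)
	for {
		tk, err := lex.nextToken()
		if err != nil {
			return err // e.g. an invalid "!=" token
		}
		if tk.Type == eofToken {
			return nil
		}
		fmt.Printf("%v %q\n", tk.Type, tk.Value) // e.g. String "name", Equal "="
	}
}
```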
+
+// lexContainsState emits a containsToken and returns to the lexStartState
+func lexContainsState(l *lexer) (lexStateFunc, error) {
+	panicIfNil(l, "lexContainsState", "lexer")
+	defer l.current.clear()
+	l.emit(containsToken, "%")
+	return lexStartState, nil
+}
+
+// lexEqualState emits an equalToken and returns to the lexStartState
+func lexEqualState(l *lexer) (lexStateFunc, error) {
+	panicIfNil(l, "lexEqualState", "lexer")
+	defer l.current.clear()
+	l.emit(equalToken, "=")
+	return lexStartState, nil
+}
+
+// lexNotEqualState scans for a notEqualToken and either returns to the
+// lexStartState or returns an error
+func lexNotEqualState(l *lexer) (lexStateFunc, error) {
+	const op = "mql.lexNotEqualState"
+	panicIfNil(l, "lexNotEqualState", "lexer")
+	defer l.current.clear()
+	nextRune := l.read()
+	switch nextRune {
+	case '=':
+		l.emit(notEqualToken, "!=")
+		return lexStartState, nil
+	default:
+		return nil, fmt.Errorf("%s: %w, got %q", op, ErrInvalidNotEqual, fmt.Sprintf("%s%s", "!", string(nextRune)))
+	}
+}
+
+// lexLeftParenState emits a startLogicalExprToken and returns to the
+// lexStartState
+func lexLeftParenState(l *lexer) (lexStateFunc, error) {
+	panicIfNil(l, "lexLeftParenState", "lexer")
+	defer l.current.clear()
+	l.emit(startLogicalExprToken, runesToString(l.current))
+	return lexStartState, nil
+}
+
+// lexRightParenState emits an endLogicalExprToken and returns to the
+// lexStartState
+func lexRightParenState(l *lexer) (lexStateFunc, error) {
+	panicIfNil(l, "lexRightParenState", "lexer")
+	defer l.current.clear()
+	l.emit(endLogicalExprToken, runesToString(l.current))
+	return lexStartState, nil
+}
+
+// lexWhitespaceState emits a whitespaceToken and returns to the lexStartState
+func lexWhitespaceState(l *lexer) (lexStateFunc, error) {
+	panicIfNil(l, "lexWhitespaceState", "lexer")
+	defer l.current.clear()
+READ_WHITESPACE:
+	for {
+		ch := l.read()
+		switch {
+		case ch == eof:
+			break READ_WHITESPACE
+		case !isSpace(ch):
+			l.unread()
+			break READ_WHITESPACE
+		}
+	}
+	l.emit(whitespaceToken, "")
+	return lexStartState, nil
+}
+
+// lexGreaterState will emit either a greaterThanToken or a
+// greaterThanOrEqualToken and return to the lexStartState
+func lexGreaterState(l *lexer) (lexStateFunc, error) {
+	panicIfNil(l, "lexGreaterState", "lexer")
+	defer l.current.clear()
+	next := l.read()
+	switch next {
+	case '=':
+		l.emit(greaterThanOrEqualToken, ">=")
+		return lexStartState, nil
+	default:
+		l.unread()
+		l.emit(greaterThanToken, ">")
+		return lexStartState, nil
+	}
+}
+
+// lexLesserState will emit either a lessThanToken or a lessThanOrEqualToken and
+// return to the lexStartState
+func lexLesserState(l *lexer) (lexStateFunc, error) {
+	panicIfNil(l, "lexLesserState", "lexer")
+	defer l.current.clear()
+	next := l.read()
+	switch next {
+	case '=':
+		l.emit(lessThanOrEqualToken, "<=")
+		return lexStartState, nil
+	default:
+		l.unread()
+		l.emit(lessThanToken, "<")
+		return lexStartState, nil
+	}
+}
+
+// lexEofState emits an eofToken and returns right back to the lexEofState
+func lexEofState(l *lexer) (lexStateFunc, error) {
+	panicIfNil(l, "lexEofState", "lexer")
+	l.emit(eofToken, "")
+	return lexEofState, nil
+}
+
+// emit sends a token to the lexer's token channel
+func (l *lexer) emit(t tokenType, v string) {
+	l.tokens <- token{
+		Type:  t,
+		Value: v,
+	}
+}
+
+// isSpace reports if r is a space
+func isSpace(r rune) bool {
+	return r == ' ' || r == '\t' || r == '\r' || r == '\n'
+}
+
+// isSpecial reports if r is a special rune
+func isSpecial(r rune) bool {
+	return r == '=' || r == '>' || r == '!' || r == '<' || r == '(' || r == ')' || r == '%'
+}
+
+// read the next rune
+func (l *lexer) read() rune {
+	ch, _, err := l.source.ReadRune()
+	if err != nil {
+		return eof
+	}
+	l.current.push(ch)
+	return ch
+}
+
+// unread the last rune read, which means that rune will be returned the next
+// time lexer.read() is called. unread also removes the last rune from the
+// lexer's stack of current runes
+func (l *lexer) unread() {
+	_ = l.source.UnreadRune() // error ignored; it only occurs when nothing has been previously read
+	_, _ = l.current.pop()
+}
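lexGreaterState and lexLesserState use one rune of lookahead: read one more rune, emit the two-rune operator if it's '=', otherwise unread and fall back to the one-rune token, so ">=" never lexes as ">" followed by "=". lexNotEqualState has the same shape but errors instead of falling back, since a bare "!" isn't a valid token in this grammar. The pattern in isolation (a standalone sketch, not code from this change):

```go
import "bufio"

// scanGreater shows the maximal-munch lookahead used by lexGreaterState.
func scanGreater(src *bufio.Reader) string {
	next, _, err := src.ReadRune()
	if err == nil && next == '=' {
		return ">=" // prefer the longer operator
	}
	if err == nil {
		_ = src.UnreadRune() // not '=': hand the rune back
	}
	return ">"
}
```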
diff --git a/lex_test.go b/lex_test.go
new file mode 100644
index 0000000..ddcab0c
--- /dev/null
+++ b/lex_test.go
@@ -0,0 +1,275 @@
+package mql
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func Test_lexKeywordState(t *testing.T) {
+	t.Parallel()
+	tests := []struct {
+		name            string
+		raw             string
+		want            token
+		wantErrIs       error
+		wantErrContains string
+	}{
+		{
+			name: "just-eof",
+			raw:  ``,
+			want: token{
+				Type:  eofToken,
+				Value: "",
+			},
+		},
+		{
+			name: "empty-quotes",
+			raw:  `""`,
+			want: token{
+				Type:  stringToken,
+				Value: ``,
+			},
+		},
+		{
+			name: "quoted-value",
+			raw:  `"value"`,
+			want: token{
+				Type:  stringToken,
+				Value: `value`,
+			},
+		},
+		{
+			name: "non-quoted-value",
+			raw:  "non-quoted-value",
+			want: token{
+				Type:  stringToken,
+				Value: "non-quoted-value",
+			},
+		},
+		{
+			name: "greater-than-in-keyword",
+			raw:  "greater>than",
+			want: token{
+				Type:  stringToken,
+				Value: "greater",
+			},
+		},
+		{
+			name: "%",
+			raw:  "%",
+			want: token{
+				Type:  containsToken,
+				Value: "%",
+			},
+		},
+		{
+			name: "and",
+			raw:  "and ",
+			want: token{
+				Type:  andToken,
+				Value: "and",
+			},
+		},
+		{
+			name: "or",
+			raw:  "or ",
+			want: token{
+				Type:  orToken,
+				Value: "or",
+			},
+		},
+		{
+			name: "greaterThan",
+			raw:  ">",
+			want: token{
+				Type:  greaterThanToken,
+				Value: ">",
+			},
+		},
+		{
+			name: "greaterThanOrEqual",
+			raw:  ">=",
+			want: token{
+				Type:  greaterThanOrEqualToken,
+				Value: ">=",
+			},
+		},
+		{
+			name: "lessThan",
+			raw:  "<",
+			want: token{
+				Type:  lessThanToken,
+				Value: "<",
+			},
+		},
+		{
+			name: "lessThanOrEqual",
+			raw:  "<=",
+			want: token{
+				Type:  lessThanOrEqualToken,
+				Value: "<=",
+			},
+		},
+		{
+			name: "equal",
+			raw:  "=",
+			want: token{
+				Type:  equalToken,
+				Value: "=",
+			},
+		},
+		{
+			name: "notEqual",
+			raw:  "!=",
+			want: token{
+				Type:  notEqualToken,
+				Value: "!=",
+			},
+		},
+		{
+			name: "notEqualError",
+			raw:  "!not",
+			want: token{
+				Type:  errToken,
+				Value: `mql.lexNotEqualState: unexpected "=" after "!"`,
+			},
+			wantErrIs:       ErrInvalidNotEqual,
+			wantErrContains: `mql.lexNotEqualState: invalid "!=" token, got "!n"`,
+		},
+		{
+			name: "startLogicalExpr",
+			raw:  "(",
+			want: token{
+				Type:  startLogicalExprToken,
+				Value: "(",
+			},
+		},
+		{
+			name: "endLogicalExpr",
+			raw:  ")",
+			want: token{
+				Type:  endLogicalExprToken,
+				Value: ")",
+			},
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			assert, require := assert.New(t), require.New(t)
+
+			lex := newLexer(tc.raw)
+			tk, err := lex.nextToken()
+			if tc.wantErrContains != "" {
+				require.Error(err)
+				if tc.wantErrIs != nil {
+					assert.ErrorIs(err, tc.wantErrIs)
+				}
+				assert.ErrorContains(err, tc.wantErrContains)
+				return
+			}
+			require.NoError(err)
+			require.NotEqualValues(tk.Type,
+				whitespaceToken,
+				startLogicalExprToken,
+				endLogicalExprToken,
+				greaterThanToken,
+				greaterThanOrEqualToken,
+				lessThanToken,
+				lessThanOrEqualToken,
+				equalToken,
+				notEqualToken,
+				containsToken,
+			)
+			assert.Equal(tc.want, tk)
+			if tk.Type == eofToken {
+				tk, err = lex.nextToken()
+				require.NoError(err)
+				assert.Equal(tc.want, tk)
+			}
+		})
+	}
+
+}
+
+func Test_lexWhitespaceState(t *testing.T) {
+	t.Parallel()
+	tests := []struct {
+		name string
+		raw  string
+		want token
+	}{
+		{
+			name: "leading-whitespace",
+			raw:  " leading",
+			want: token{
+				Type:  whitespaceToken,
+				Value: "",
+			},
+		},
+		{
+			name: "trailing-whitespace",
+			raw:  " ",
+			want: token{
+				Type:  whitespaceToken,
+				Value: "",
+			},
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			assert, require := assert.New(t), require.New(t)
+
+			lex := newLexer(tc.raw)
+			tk, err := lex.nextToken()
+			require.NoError(err)
+			require.NotEqualValues(tk.Type,
+				errToken,
+				eofToken,
+				stringToken,
+				startLogicalExprToken,
+				endLogicalExprToken,
+				greaterThanToken,
+				greaterThanOrEqualToken,
+				lessThanToken,
+				lessThanOrEqualToken,
+				equalToken,
+				notEqualToken,
+				containsToken,
+			)
+			assert.Equal(tc.want, tk)
+		})
+	}
+
+}
+
+// Fuzz_lexerNextToken is focused on finding panics and unknown tokens
+func Fuzz_lexerNextToken(f *testing.F) {
+	tc := []string{">=!=", "string ( ) > >=", "< <= = != AND OR and or", "1 != \"2\""}
+	for _, tc := range tc {
+		f.Add(tc)
+	}
+	f.Fuzz(func(t *testing.T, s string) {
+		helperFn := func(lex *lexer) []token {
+			var tokens []token
+			for {
+				tok, err := lex.nextToken()
+				if err != nil {
+					return tokens
+				}
+				tokens = append(tokens, tok)
+				if tok.Type == eofToken {
+					return tokens
+				}
+			}
+		}
+		lex := newLexer(s)
+		tokens := helperFn(lex)
+		for _, token := range tokens {
+			if token.Type.String() == "Unknown" {
+				t.Errorf("unexpected token %v", token)
+			}
+		}
+	})
+}
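The notEqualError case exercises both assertion styles: assert.ErrorIs passes because lexNotEqualState wraps ErrInvalidNotEqual with %w, and assert.ErrorContains matches the formatted message. The same contract without testify (a hypothetical check inside package mql, with `errors` and `fmt` imported):

```go
// checkNotEqualErr is a hypothetical illustration of the error contract.
func checkNotEqualErr() {
	_, err := newLexer("!bogus").nextToken()
	fmt.Println(errors.Is(err, ErrInvalidNotEqual)) // true: %w keeps the sentinel unwrappable
	fmt.Println(err)                                // mql.lexNotEqualState: invalid "!=" token, got "!b"
}
```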
diff --git a/stack.go b/stack.go
new file mode 100644
index 0000000..561374c
--- /dev/null
+++ b/stack.go
@@ -0,0 +1,36 @@
+// Copyright (c) HashiCorp, Inc.
+
+package mql
+
+type stack[T any] struct {
+	data []T
+}
+
+func (s *stack[T]) push(v T) {
+	s.data = append(s.data, v)
+}
+
+func (s *stack[T]) pop() (T, bool) {
+	var x T
+	if len(s.data) > 0 {
+		x, s.data = s.data[len(s.data)-1], s.data[:len(s.data)-1]
+		return x, true
+	}
+	return x, false
+}
+
+func (s *stack[T]) clear() {
+	s.data = nil
+}
+
+func runesToString(s stack[rune]) string {
+	var result string
+	for {
+		r, ok := s.pop()
+		if !ok {
+			break
+		}
+		result = string(r) + result
+	}
+	return result
+}
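One subtlety in stack.go: runesToString takes the stack by value and drains its own copy, popping in LIFO order and prepending each rune so the original insertion order is restored; the caller's stack is left intact. A round-trip sketch (hypothetical, inside package mql, `fmt` imported):

```go
// runesRoundTrip is a hypothetical illustration, not part of this change.
func runesRoundTrip() {
	var s stack[rune]
	for _, r := range "abc" {
		s.push(r)
	}
	fmt.Println(runesToString(s)) // "abc": pops 'c', 'b', 'a', prepending each
}
```

Prepending builds the string in quadratic time, which is fine here since the stack only ever holds the runes of the current token.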
diff --git a/token.go b/token.go
new file mode 100644
index 0000000..fdc2d08
--- /dev/null
+++ b/token.go
@@ -0,0 +1,64 @@
+// Copyright (c) HashiCorp, Inc.
+
+package mql
+
+type token struct {
+	Type  tokenType
+	Value string
+}
+
+type tokenType int
+
+const eof rune = -1
+
+const (
+	unknownToken tokenType = iota
+	errToken
+	eofToken
+	whitespaceToken
+	stringToken
+	startLogicalExprToken
+	endLogicalExprToken
+	greaterThanToken
+	greaterThanOrEqualToken
+	lessThanToken
+	lessThanOrEqualToken
+	equalToken
+	notEqualToken
+	containsToken
+
+	// keywords
+	andToken
+	orToken
+)
+
+var tokenTypeToString = map[tokenType]string{
+	unknownToken:            "Unknown",
+	errToken:                "Error",
+	eofToken:                "Eof",
+	whitespaceToken:         "Whitespace",
+	stringToken:             "String",
+	startLogicalExprToken:   "(",
+	endLogicalExprToken:     ")",
+	greaterThanToken:        "Greater Than",
+	greaterThanOrEqualToken: "Greater Than Or Equal",
+	lessThanToken:           "Less Than",
+	lessThanOrEqualToken:    "Less Than or Equal",
+	equalToken:              "Equal",
+	notEqualToken:           "Not Equal",
+	containsToken:           "Contains",
+	andToken:                "And",
+	orToken:                 "Or",
+}
+
+// String returns a string for the tokenType and will return "Unknown" for
+// invalid tokenTypes
+func (t tokenType) String() string {
+	s, ok := tokenTypeToString[t]
+	switch ok {
+	case true:
+		return s
+	default:
+		return "Unknown"
+	}
+}
diff --git a/token_test.go b/token_test.go
new file mode 100644
index 0000000..f3d9c8b
--- /dev/null
+++ b/token_test.go
@@ -0,0 +1,18 @@
+package mql
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func Test_tokenTypeString(t *testing.T) {
+	for typ, s := range tokenTypeToString {
+		assert := assert.New(t)
+		assert.Equal(s, typ.String())
+	}
+	t.Run("unknown-tokenType", func(t *testing.T) {
+		typ := tokenType(-1)
+		assert.Equal(t, "Unknown", typ.String())
+	})
+}
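tokenType.String is total: any value missing from tokenTypeToString, including a future constant that never gets added to the map, stringifies as "Unknown", which is exactly the signal Fuzz_lexerNextToken keys on. For example (hypothetical, inside package mql; fmt picks up the String method via %v):

```go
fmt.Println(equalToken)    // Equal
fmt.Println(tokenType(99)) // Unknown
```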