diff --git a/README.md b/README.md
index 0f44f78..07efe79 100644
--- a/README.md
+++ b/README.md
@@ -68,6 +68,34 @@ func main() {
 }
 ```
 
+Parsing with SQL mode `ANSI_QUOTES`:
+
+With `ANSI_QUOTES` enabled, `"` is treated as an identifier quote character (like the \` quote character) rather than as a string quote character. You can still use \` to quote identifiers in this mode, but you can no longer use double quotation marks to quote literal strings, because they are interpreted as identifiers.
+
+```go
+package main
+
+import (
+	"github.com/SananGuliyev/sqlparser"
+)
+
+func main() {
+	sql := `SELECT * FROM "table" WHERE a = 'abc'`
+	sqlparser.SQLMode = sqlparser.SQLModeANSIQuotes
+	stmt, err := sqlparser.Parse(sql)
+	if err != nil {
+		// Do something with the err
+	}
+
+	// Otherwise do something with stmt
+	switch stmt := stmt.(type) {
+	case *sqlparser.Select:
+		_ = stmt
+	case *sqlparser.Insert:
+	}
+}
+```
+
 See [parse_test.go](https://github.com/SananGuliyev/sqlparser/blob/master/parse_test.go) for more examples, or read the [godoc](https://godoc.org/github.com/SananGuliyev/sqlparser).
diff --git a/analyzer.go b/analyzer.go
index 4b5559f..8c4c6af 100644
--- a/analyzer.go
+++ b/analyzer.go
@@ -300,7 +300,7 @@ func ExtractSetValues(sql string) (keyValues map[SetKey]interface{}, scope strin
 		if setStmt.Scope != "" && scope != "" {
 			return nil, "", fmt.Errorf("unsupported in set: mixed using of variable scope")
 		}
-		_, out := NewStringTokenizer(key).Scan()
+		_, out := NewStringTokenizer(key, SQLMode).Scan()
 		key = string(out)
 	}
diff --git a/ast.go b/ast.go
index 1be4c04..1326556 100644
--- a/ast.go
+++ b/ast.go
@@ -29,6 +29,13 @@ import (
 	"github.com/SananGuliyev/sqlparser/dependency/sqltypes"
 )
 
+const (
+	SQLModeStandard = iota
+	SQLModeANSIQuotes
+)
+
+var SQLMode = SQLModeStandard
+
 // Instructions for creating new types: If a type
 // needs to satisfy an interface, declare that function
 // along with that interface. This will help users
@@ -46,7 +53,7 @@ import (
 // is partially parsed but still contains a syntax error, the
 // error is ignored and the DDL is returned anyway.
 func Parse(sql string) (Statement, error) {
-	tokenizer := NewStringTokenizer(sql)
+	tokenizer := NewStringTokenizer(sql, SQLMode)
 	if yyParse(tokenizer) != 0 {
 		if tokenizer.partialDDL != nil {
 			log.Printf("ignoring error parsing DDL '%s': %v", sql, tokenizer.LastError)
@@ -61,7 +68,7 @@ func Parse(sql string) (Statement, error) {
 // ParseStrictDDL is the same as Parse except it errors on
 // partially parsed DDL statements.
 func ParseStrictDDL(sql string) (Statement, error) {
-	tokenizer := NewStringTokenizer(sql)
+	tokenizer := NewStringTokenizer(sql, SQLMode)
 	if yyParse(tokenizer) != 0 {
 		return nil, tokenizer.LastError
 	}
@@ -97,7 +104,7 @@ func ParseNext(tokenizer *Tokenizer) (Statement, error) {
 // SplitStatement returns the first sql statement up to either a ; or EOF
 // and the remainder from the given buffer
 func SplitStatement(blob string) (string, string, error) {
-	tokenizer := NewStringTokenizer(blob)
+	tokenizer := NewStringTokenizer(blob, SQLMode)
 	tkn := 0
 	for {
 		tkn, _ = tokenizer.Scan()
@@ -118,7 +125,7 @@ func SplitStatement(blob string) (string, string, error) {
 // returns the sql pieces blob contains; or error if sql cannot be parsed
 func SplitStatementToPieces(blob string) (pieces []string, err error) {
 	pieces = make([]string, 0, 16)
-	tokenizer := NewStringTokenizer(blob)
+	tokenizer := NewStringTokenizer(blob, SQLMode)
 	tkn := 0
 
 	var stmt string
@@ -3430,6 +3437,12 @@ func Backtick(in string) string {
 }
 
 func formatID(buf *TrackedBuffer, original, lowered string) {
+	var identChar rune
+	if SQLMode == SQLModeANSIQuotes {
+		identChar = '"'
+	} else {
+		identChar = '`'
+	}
 	isDbSystemVariable := false
 	if len(original) > 1 && original[:2] == "@@" {
 		isDbSystemVariable = true
@@ -3449,14 +3462,14 @@ func formatID(buf *TrackedBuffer, original, lowered string) {
 	return
 
 mustEscape:
-	buf.WriteByte('`')
+	_, _ = buf.WriteRune(identChar)
 	for _, c := range original {
-		buf.WriteRune(c)
-		if c == '`' {
-			buf.WriteByte('`')
+		_, _ = buf.WriteRune(c)
+		if c == identChar {
+			_, _ = buf.WriteRune(identChar)
 		}
 	}
-	buf.WriteByte('`')
+	_, _ = buf.WriteRune(identChar)
 }
 
 func compliantName(in string) string {
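The `formatID` hunk above is the serializer half of the feature: it picks the active identifier-quote character and doubles it wherever it occurs inside the identifier. A minimal standalone sketch of that escaping rule (the `escapeID` helper is hypothetical, for illustration only, not part of the package):

```go
package main

import (
	"fmt"
	"strings"
)

// escapeID mirrors formatID's quoting rule: wrap the identifier in the
// active quote character and double any embedded occurrence of it.
// (Hypothetical helper, for illustration only.)
func escapeID(ident string, ansiQuotes bool) string {
	quote := "`"
	if ansiQuotes {
		quote = `"`
	}
	return quote + strings.ReplaceAll(ident, quote, quote+quote) + quote
}

func main() {
	fmt.Println(escapeID("my`table", false)) // `my``table`
	fmt.Println(escapeID(`my"table`, true))  // "my""table"
}
```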
diff --git a/parse_next_test.go b/parse_next_test.go
index bb92f9f..890ef70 100644
--- a/parse_next_test.go
+++ b/parse_next_test.go
@@ -67,7 +67,7 @@ func TestParseNextErrors(t *testing.T) {
 		}
 		sql := tcase.input + "; select 1 from t"
-		tokens := NewStringTokenizer(sql)
+		tokens := NewStringTokenizer(sql, SQLMode)
 
 		// The first statement should be an error
 		_, err := ParseNext(tokens)
@@ -136,13 +136,12 @@ func TestParseNextEdgeCases(t *testing.T) {
 	}}
 
 	for _, test := range tests {
-		tokens := NewStringTokenizer(test.input)
+		tokens := NewStringTokenizer(test.input, SQLMode)
 
 		for i, want := range test.want {
 			tree, err := ParseNext(tokens)
 			if err != nil {
 				t.Fatalf("[%d] ParseNext(%q) err = %q, want nil", i, test.input, err)
-				continue
 			}
 
 			if got := String(tree); got != want {
diff --git a/token.go b/token.go
index 47a26d1..1157162 100644
--- a/token.go
+++ b/token.go
@@ -44,6 +44,7 @@ type Tokenizer struct {
 	posVarIndex    int
 	ParseTree      Statement
 	partialDDL     *DDL
+	sqlMode        int
 	nesting        int
 	multi          bool
 	specialComment *Tokenizer
@@ -55,11 +56,12 @@ type Tokenizer struct {
 
 // NewStringTokenizer creates a new Tokenizer for the
 // sql string.
-func NewStringTokenizer(sql string) *Tokenizer {
+func NewStringTokenizer(sql string, sqlMode int) *Tokenizer {
 	buf := []byte(sql)
 	return &Tokenizer{
 		buf:     buf,
 		bufSize: len(buf),
+		sqlMode: sqlMode,
 	}
 }
 
@@ -595,7 +597,12 @@ func (tkn *Tokenizer) Scan() (int, []byte) {
 			return NE, nil
 		}
 		return int(ch), nil
-	case '\'', '"':
+	case '\'':
+		return tkn.scanString(ch, STRING)
+	case '"':
+		if tkn.sqlMode == SQLModeANSIQuotes {
+			return tkn.scanLiteralIdentifier()
+		}
 		return tkn.scanString(ch, STRING)
 	case '`':
 		return tkn.scanLiteralIdentifier()
@@ -667,13 +674,23 @@ func (tkn *Tokenizer) scanBitLiteral() (int, []byte) {
 func (tkn *Tokenizer) scanLiteralIdentifier() (int, []byte) {
 	buffer := &bytes2.Buffer{}
 	backTickSeen := false
+	quoteSeen := false
 	for {
 		if backTickSeen {
 			if tkn.lastChar != '`' {
 				break
 			}
 			backTickSeen = false
-			buffer.WriteByte('`')
+			_ = buffer.WriteByte('`')
+			tkn.next()
+			continue
+		}
+		if quoteSeen {
+			if tkn.lastChar != '"' {
+				break
+			}
+			quoteSeen = false
+			_ = buffer.WriteByte('"')
 			tkn.next()
 			continue
 		}
@@ -681,11 +698,17 @@
 		switch tkn.lastChar {
 		case '`':
 			backTickSeen = true
+		case '"':
+			if tkn.sqlMode == SQLModeANSIQuotes {
+				quoteSeen = true
+			} else {
+				_ = buffer.WriteByte(byte(tkn.lastChar))
+			}
 		case eofChar:
 			// Premature EOF.
 			return LEX_ERROR, buffer.Bytes()
 		default:
-			buffer.WriteByte(byte(tkn.lastChar))
+			_ = buffer.WriteByte(byte(tkn.lastChar))
 		}
 		tkn.next()
 	}
@@ -880,7 +903,7 @@ func (tkn *Tokenizer) scanMySQLSpecificComment() (int, []byte) {
 		tkn.consumeNext(buffer)
 	}
 	_, sql := ExtractMysqlComment(buffer.String())
-	tkn.specialComment = NewStringTokenizer(sql)
+	tkn.specialComment = NewStringTokenizer(sql, tkn.sqlMode)
 	return tkn.Scan()
 }
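The tokenizer half can be exercised directly. A sketch, assuming (as the package's own tests do) that `Scan` returns `0` once the input is exhausted:

```go
package main

import (
	"fmt"

	"github.com/SananGuliyev/sqlparser"
)

func main() {
	// Under ANSI_QUOTES, `"a""b"` is routed to scanLiteralIdentifier, so it
	// scans as a single identifier token whose value is a"b (the doubled
	// quote collapses to one).
	tkn := sqlparser.NewStringTokenizer(`select "a""b" from t`, sqlparser.SQLModeANSIQuotes)
	for {
		id, val := tkn.Scan()
		if id == 0 { // assumed end-of-input sentinel
			break
		}
		fmt.Printf("token %d: %q\n", id, val)
	}
}
```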
diff --git a/token_test.go b/token_test.go
index 9354354..e283329 100644
--- a/token_test.go
+++ b/token_test.go
@@ -57,7 +57,7 @@ func TestLiteralID(t *testing.T) {
 	}}
 
 	for _, tcase := range testcases {
-		tkn := NewStringTokenizer(tcase.in)
+		tkn := NewStringTokenizer(tcase.in, SQLMode)
 		id, out := tkn.Scan()
 		if tcase.id != id || string(out) != tcase.out {
 			t.Errorf("Scan(%s): %d, %s, want %d, %s", tcase.in, id, out, tcase.id, tcase.out)
@@ -130,7 +130,7 @@ func TestString(t *testing.T) {
 	}}
 
 	for _, tcase := range testcases {
-		id, got := NewStringTokenizer(tcase.in).Scan()
+		id, got := NewStringTokenizer(tcase.in, SQLMode).Scan()
 		if tcase.id != id || string(got) != tcase.want {
 			t.Errorf("Scan(%q) = (%s, %q), want (%s, %q)", tcase.in, tokenName(id), got, tokenName(tcase.id), tcase.want)
 		}
@@ -189,3 +189,31 @@ func TestSplitStatement(t *testing.T) {
 		}
 	}
 }
+
+func TestParseANSIQuotesMode(t *testing.T) {
+	testcases := []struct {
+		in  string
+		out string
+	}{{
+		in:  `select * from "table"`,
+		out: `select * from "table"`, // reserved word stays quoted
+	}, {
+		in:  `select * from "tbl"`,
+		out: `select * from tbl`, // quoting dropped where not required
+	}}
+
+	SQLMode = SQLModeANSIQuotes
+	for _, tcase := range testcases {
+		stmt, err := Parse(tcase.in)
+		if err != nil {
+			t.Errorf("Parse(%s): ERROR: %v", tcase.in, err)
+			continue
+		}
+
+		finalSQL := String(stmt)
+		if tcase.out != finalSQL {
+			t.Errorf("Parse(%s) got sql %q want %q", tcase.in, finalSQL, tcase.out)
+		}
+	}
+	SQLMode = SQLModeStandard
+}
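Note that `SQLMode` is package-level state, which is why the test above restores `SQLModeStandard` before returning. Callers should use the same save-and-restore pattern, and avoid changing the mode while another goroutine is parsing. A usage sketch:

```go
package main

import (
	"fmt"
	"log"

	"github.com/SananGuliyev/sqlparser"
)

func main() {
	// Enable ANSI_QUOTES for this parse, then restore the default so other
	// code in the process is unaffected (mirroring TestParseANSIQuotesMode).
	sqlparser.SQLMode = sqlparser.SQLModeANSIQuotes
	defer func() { sqlparser.SQLMode = sqlparser.SQLModeStandard }()

	stmt, err := sqlparser.Parse(`select * from "table" where a = 'abc'`)
	if err != nil {
		log.Fatal(err)
	}
	// `table` is a reserved word, so formatID keeps it quoted on the way out:
	// select * from "table" where a = 'abc'
	fmt.Println(sqlparser.String(stmt))
}
```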
diff --git a/tracked_buffer.go b/tracked_buffer.go
index ec421a5..318dac8 100644
--- a/tracked_buffer.go
+++ b/tracked_buffer.go
@@ -68,7 +68,7 @@ func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) {
 			i++
 		}
 		if i > lasti {
-			buf.WriteString(format[lasti:i])
+			_, _ = buf.WriteString(format[lasti:i])
 		}
 		if i >= end {
 			break
 		}
@@ -78,18 +78,18 @@ func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) {
 		case 'c':
 			switch v := values[fieldnum].(type) {
 			case byte:
-				buf.WriteByte(v)
+				_ = buf.WriteByte(v)
 			case rune:
-				buf.WriteRune(v)
+				_, _ = buf.WriteRune(v)
 			default:
 				panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
 			}
 		case 's':
 			switch v := values[fieldnum].(type) {
 			case []byte:
-				buf.Write(v)
+				_, _ = buf.Write(v)
 			case string:
-				buf.WriteString(v)
+				_, _ = buf.WriteString(v)
 			default:
 				panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
 			}
@@ -118,7 +118,7 @@ func (buf *TrackedBuffer) WriteArg(arg string) {
 		offset: buf.Len(),
 		length: len(arg),
 	})
-	buf.WriteString(arg)
+	_, _ = buf.WriteString(arg)
 }
 
 // ParsedQuery returns a ParsedQuery that contains bind
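The `_, _ =` and `_ =` blank assignments here (and in `formatID` and the tokenizer above) only satisfy unchecked-error linters: `TrackedBuffer` appears to wrap a `bytes.Buffer`, and `bytes.Buffer`'s write methods are documented to always return a nil error. For instance:

```go
package main

import (
	"bytes"
	"fmt"
)

func main() {
	// bytes.Buffer grows as needed and its write methods never fail, so the
	// returned error is always nil and can be discarded safely.
	var b bytes.Buffer
	n, err := b.WriteString("select 1")
	fmt.Println(n, err) // 8 <nil>
}
```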