Skip to content
This repository has been archived by the owner on Sep 7, 2021. It is now read-only.
This repository is currently being migrated. It's locked while the migration is in progress.

Commit

Permalink
fix quote policy
Browse files Browse the repository at this point in the history
  • Loading branch information
lunny committed Sep 30, 2019
1 parent ecc286a commit cac7688
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 83 deletions.
10 changes: 5 additions & 5 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ type Engine struct {

defaultContext context.Context

quotePolicy QuotePolicy
quoteMode QuoteMode
colQuoter Quoter
tableQuoter Quoter
}

func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {
Expand Down Expand Up @@ -419,7 +419,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
return err
}

quoter := newQuoter(dialect, engine.quoteMode, engine.quotePolicy)
colQuoter := newQuoter(dialect, engine.colQuoter.QuotePolicy())

for i, table := range tables {
if i > 0 {
Expand All @@ -440,8 +440,8 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
}

cols := table.ColumnsSeq()
colNames := quoteJoin(engine, cols)
destColNames := quoteJoin(quoter, cols)
colNames := quoteJoin(engine.colQuoter, cols)
destColNames := quoteJoin(colQuoter, cols)

rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.quote(table.Name, false))
if err != nil {
Expand Down
76 changes: 16 additions & 60 deletions engine_quote.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,21 @@ const (
QuoteAddReserved
)

// QuoteMode quote on which types
type QuoteMode int

// All QuoteModes
const (
QuoteTableAndColumns QuoteMode = iota
QuoteTableOnly
QuoteColumnsOnly
)

// Quoter represents an object has Quote method
type Quoter interface {
Quotes() (byte, byte)
QuotePolicy() QuotePolicy
QuoteMode() QuoteMode
IsReserved(string) bool
}

type quoter struct {
dialect core.Dialect
quoteMode QuoteMode
quotePolicy QuotePolicy
}

func newQuoter(dialect core.Dialect, quoteMode QuoteMode, quotePolicy QuotePolicy) Quoter {
func newQuoter(dialect core.Dialect, quotePolicy QuotePolicy) Quoter {
return &quoter{
dialect: dialect,
quoteMode: quoteMode,
quotePolicy: quotePolicy,
}
}
Expand All @@ -62,10 +49,6 @@ func (q *quoter) QuotePolicy() QuotePolicy {
return q.quotePolicy
}

func (q *quoter) QuoteMode() QuoteMode {
return q.quoteMode
}

func (q *quoter) IsReserved(value string) bool {
return q.dialect.IsReserved(value)
}
Expand All @@ -77,21 +60,24 @@ func quoteColumns(quoter Quoter, columnStr string) string {

func quoteJoin(quoter Quoter, columns []string) string {
for i := 0; i < len(columns); i++ {
columns[i] = quote(quoter, columns[i], true)
columns[i] = quote(quoter, columns[i])
}
return strings.Join(columns, ",")
}

// quote Use QuoteStr quote the string sql
func quote(quoter Quoter, value string, isColumn bool) string {
func quote(quoter Quoter, value string) string {
buf := strings.Builder{}
quoteTo(quoter, &buf, value, isColumn)
quoteTo(quoter, &buf, value)
return buf.String()
}

// Quote add quotes to the value
func (engine *Engine) quote(value string, isColumn bool) string {
return quote(engine, value, isColumn)
if isColumn {
return quote(engine.colQuoter, value)
}
return quote(engine.tableQuoter, value)
}

// Quote add quotes to the value
Expand All @@ -105,53 +91,25 @@ func (engine *Engine) Quotes() (byte, byte) {
return quotes[0], quotes[1]
}

// QuoteMode returns quote mode
func (engine *Engine) QuoteMode() QuoteMode {
return engine.quoteMode
}

// QuotePolicy returns quote policy
func (engine *Engine) QuotePolicy() QuotePolicy {
return engine.quotePolicy
}

// IsReserved return true if the value is a reserved word of the database
func (engine *Engine) IsReserved(value string) bool {
return engine.dialect.IsReserved(value)
}

// quoteTo quotes string and writes into the buffer
func quoteTo(quoter Quoter, buf *strings.Builder, value string, isColumn bool) {
if isColumn {
if quoter.QuoteMode() == QuoteTableAndColumns ||
quoter.QuoteMode() == QuoteColumnsOnly {
if quoter.QuotePolicy() == QuoteAddAlways {
realQuoteTo(quoter, buf, value)
return
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
realQuoteTo(quoter, buf, value)
return
}
}
buf.WriteString(value)
func quoteTo(quoter Quoter, buf *strings.Builder, value string) {
left, right := quoter.Quotes()
if quoter.QuotePolicy() == QuoteAddAlways {
realQuoteTo(left, right, buf, value)
return
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
realQuoteTo(left, right, buf, value)
return
}

if quoter.QuoteMode() == QuoteTableAndColumns ||
quoter.QuoteMode() == QuoteTableOnly {
if quoter.QuotePolicy() == QuoteAddAlways {
realQuoteTo(quoter, buf, value)
return
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
realQuoteTo(quoter, buf, value)
return
}
}
buf.WriteString(value)
return
}

func realQuoteTo(quoter Quoter, buf *strings.Builder, value string) {
func realQuoteTo(quoteLeft, quoteRight byte, buf *strings.Builder, value string) {
if buf == nil {
return
}
Expand All @@ -164,8 +122,6 @@ func realQuoteTo(quoter Quoter, buf *strings.Builder, value string) {
return
}

quoteLeft, quoteRight := quoter.Quotes()

if value[0] == '`' || value[0] == quoteLeft { // no quote
_, _ = buf.WriteString(value)
return
Expand Down
4 changes: 2 additions & 2 deletions session_find.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,15 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
if session.statement.JoinStr == "" {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
} else {
columnStr = session.statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
} else {
columnStr = "*"
}
Expand Down
8 changes: 4 additions & 4 deletions session_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,15 +249,15 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if session.engine.dialect.DBType() == core.ORACLE {
temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
session.engine.quote(tableName, false),
quoteJoin(session.engine, colNames))
quoteJoin(session.engine.colQuoter, colNames))
sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
session.engine.quote(tableName, false),
quoteJoin(session.engine, colNames),
quoteJoin(session.engine.colQuoter, colNames),
strings.Join(colMultiPlaces, temp))
} else {
sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
session.engine.quote(tableName, false),
quoteJoin(session.engine, colNames),
quoteJoin(session.engine.colQuoter, colNames),
strings.Join(colMultiPlaces, "),("))
}
res, err := session.exec(sql, args...)
Expand Down Expand Up @@ -855,7 +855,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {

if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
session.engine.quote(tableName, false),
quoteJoin(session.engine, columns), qm)); err != nil {
quoteJoin(session.engine.colQuoter, columns), qm)); err != nil {
return 0, err
}
w.Append(args...)
Expand Down
4 changes: 2 additions & 2 deletions session_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interfa
if session.statement.JoinStr == "" {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
} else {
columnStr = session.statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = quoteColumns(session.engine, session.statement.GroupByStr)
columnStr = quoteColumns(session.engine.colQuoter, session.statement.GroupByStr)
} else {
columnStr = "*"
}
Expand Down
2 changes: 1 addition & 1 deletion session_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
for idx, kv := range kvs {
sps := strings.SplitN(kv, "=", 2)
sps2 := strings.Split(sps[0], ".")
colName := unQuote(session.engine, sps2[len(sps2)-1])
colName := unQuote(session.engine.colQuoter, sps2[len(sps2)-1])

if col := table.GetColumn(colName); col != nil {
fieldValue, err := col.ValueOf(bean)
Expand Down
14 changes: 7 additions & 7 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {

newColumns := statement.colmap2NewColsWithQuote()

statement.ColumnStr = quoteJoin(statement.Engine, newColumns)
statement.ColumnStr = quoteJoin(statement.Engine.colQuoter, newColumns)
return statement
}

Expand Down Expand Up @@ -638,7 +638,7 @@ func (statement *Statement) Omit(columns ...string) {
for _, nc := range newColumns {
statement.omitColumnMap = append(statement.omitColumnMap, nc)
}
statement.OmitStr = quoteJoin(statement.Engine, newColumns)
statement.OmitStr = quoteJoin(statement.Engine.colQuoter, newColumns)
}

// Nullable Update use only: update columns to null when value is nullable and zero-value
Expand Down Expand Up @@ -732,7 +732,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
}
tbs := strings.Split(tp.TableName(), ".")

var aliasName = unQuote(statement.Engine, tbs[len(tbs)-1])
var aliasName = unQuote(statement.Engine.tableQuoter, tbs[len(tbs)-1])
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
case *builder.Builder:
Expand All @@ -743,7 +743,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
}
tbs := strings.Split(tp.TableName(), ".")

var aliasName = unQuote(statement.Engine, tbs[len(tbs)-1])
var aliasName = unQuote(statement.Engine.tableQuoter, tbs[len(tbs)-1])
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default:
Expand Down Expand Up @@ -809,7 +809,7 @@ func (statement *Statement) genColumnStr() string {
buf.WriteString(".")
}

quoteTo(statement.Engine, &buf, col.Name, true)
quoteTo(statement.Engine.colQuoter, &buf, col.Name)
}

return buf.String()
Expand Down Expand Up @@ -928,15 +928,15 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
if len(statement.JoinStr) == 0 {
if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 {
columnStr = quoteColumns(statement.Engine, statement.GroupByStr)
columnStr = quoteColumns(statement.Engine.colQuoter, statement.GroupByStr)
} else {
columnStr = statement.genColumnStr()
}
}
} else {
if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 {
columnStr = quoteColumns(statement.Engine, statement.GroupByStr)
columnStr = quoteColumns(statement.Engine.colQuoter, statement.GroupByStr)
}
}
}
Expand Down
7 changes: 5 additions & 2 deletions statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ func TestCol2NewColsWithQuote(t *testing.T) {

statement := createTestStatement()

quotedCols := quoteJoin(statement.Engine, cols)
assert.EqualValues(t, []string{statement.Engine.Quote("f1", true), statement.Engine.Quote("f2", true), statement.Engine.Quote("t3.f3", true)}, quotedCols)
quotedCols := quoteJoin(statement.Engine.colQuoter, cols)
assert.EqualValues(t, statement.Engine.Quote("f1", true)+","+
statement.Engine.Quote("f2", true)+","+
statement.Engine.Quote("t3.f3", true),
quotedCols)
}
2 changes: 2 additions & 0 deletions xorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
tagHandlers: defaultTagHandlers,
cachers: make(map[string]core.Cacher),
defaultContext: context.Background(),
colQuoter: newQuoter(dialect, QuoteAddAlways),
tableQuoter: newQuoter(dialect, QuoteAddAlways),
}

if uri.DbType == core.SQLITE {
Expand Down

0 comments on commit cac7688

Please sign in to comment.