Skip to content

Commit

Permalink
Cherry-pick cd61d85 with conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
vitess-bot[bot] committed Feb 12, 2024
1 parent f359bb9 commit 12c3605
Show file tree
Hide file tree
Showing 7 changed files with 267 additions and 55 deletions.
73 changes: 51 additions & 22 deletions go/test/endtoend/utils/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,21 @@ import (

"github.com/stretchr/testify/assert"

<<<<<<< HEAD
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/dbconfigs"
"vitess.io/vitess/go/vt/sqlparser"

=======
>>>>>>> cd61d85130 (bugfix: wrong field type returned for SUM (#15192))
"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/vt/dbconfigs"
"vitess.io/vitess/go/vt/mysqlctl"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
)

// NewMySQL creates a new MySQL server using the local mysqld binary. The name of the database
Expand Down Expand Up @@ -155,7 +163,9 @@ func prepareMySQLWithSchema(params mysql.ConnParams, sql string) error {
return nil
}

func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn, vtQr, mysqlQr *sqltypes.Result, compareColumns bool) error {
func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn, vtQr, mysqlQr *sqltypes.Result, compareColumnNames bool) error {
t.Helper()

if vtQr == nil && mysqlQr == nil {
return nil
}
Expand All @@ -168,29 +178,34 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn
return errors.New("MySQL result is 'nil' while Vitess' is not.\n")
}

var errStr string
if compareColumns {
vtColCount := len(vtQr.Fields)
myColCount := len(mysqlQr.Fields)
if vtColCount > 0 && myColCount > 0 {
if vtColCount != myColCount {
t.Errorf("column count does not match: %d vs %d", vtColCount, myColCount)
errStr += fmt.Sprintf("column count does not match: %d vs %d\n", vtColCount, myColCount)
}

var vtCols []string
var myCols []string
for i, vtField := range vtQr.Fields {
vtCols = append(vtCols, vtField.Name)
myCols = append(myCols, mysqlQr.Fields[i].Name)
}
if !assert.Equal(t, myCols, vtCols, "column names do not match - the expected values are what mysql produced") {
errStr += "column names do not match - the expected values are what mysql produced\n"
errStr += fmt.Sprintf("Not equal: \nexpected: %v\nactual: %v\n", myCols, vtCols)
}
vtColCount := len(vtQr.Fields)
myColCount := len(mysqlQr.Fields)

if vtColCount != myColCount {
t.Errorf("column count does not match: %d vs %d", vtColCount, myColCount)
}

if vtColCount > 0 {
var vtCols []string
var myCols []string
for i, vtField := range vtQr.Fields {
myField := mysqlQr.Fields[i]
checkFields(t, myField.Name, vtField, myField)

vtCols = append(vtCols, vtField.Name)
myCols = append(myCols, myField.Name)
}

if compareColumnNames && !assert.Equal(t, myCols, vtCols, "column names do not match - the expected values are what mysql produced") {
t.Errorf("column names do not match - the expected values are what mysql produced\nNot equal: \nexpected: %v\nactual: %v\n", myCols, vtCols)
}
}
<<<<<<< HEAD
stmt, err := sqlparser.Parse(query)
=======

stmt, err := sqlparser.NewTestParser().Parse(query)
>>>>>>> cd61d85130 (bugfix: wrong field type returned for SUM (#15192))
if err != nil {
t.Error(err)
return err
Expand All @@ -204,7 +219,7 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn
return nil
}

errStr += "Query (" + query + ") results mismatched.\nVitess Results:\n"
errStr := "Query (" + query + ") results mismatched.\nVitess Results:\n"
for _, row := range vtQr.Rows {
errStr += fmt.Sprintf("%s\n", row)
}
Expand All @@ -224,6 +239,20 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn
return errors.New(errStr)
}

func checkFields(t *testing.T, columnName string, vtField, myField *querypb.Field) {
t.Helper()
if vtField.Type != myField.Type {
t.Errorf("for column %s field types do not match\nNot equal: \nMySQL: %v\nVitess: %v\n", columnName, myField.Type.String(), vtField.Type.String())
}

// starting in Vitess 20, decimal types are properly sized in their field information
if BinaryIsAtLeastAtVersion(20, "vtgate") && vtField.Type == sqltypes.Decimal {
if vtField.Decimals != myField.Decimals {
t.Errorf("for column %s field decimals count do not match\nNot equal: \nMySQL: %v\nVitess: %v\n", columnName, myField.Decimals, vtField.Decimals)
}
}
}

func compareVitessAndMySQLErrors(t *testing.T, vtErr, mysqlErr error) {
if vtErr != nil && mysqlErr != nil || vtErr == nil && mysqlErr == nil {
return
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/aggregations.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ func (ap *AggregateParams) String() string {

func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type {
if ap.OrigOpcode != AggregateUnassigned {
return ap.OrigOpcode.Type(inputType)
return ap.OrigOpcode.SQLType(inputType)
}
return ap.Opcode.Type(inputType)
return ap.Opcode.SQLType(inputType)
}

type aggregator interface {
Expand Down
26 changes: 25 additions & 1 deletion go/vt/vtgate/engine/opcode/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ package opcode
import (
"fmt"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/vtgate/evalengine"
)

// PulloutOpcode is a number representing the opcode
Expand Down Expand Up @@ -134,7 +136,7 @@ func (code AggregateOpcode) MarshalJSON() ([]byte, error) {
}

// Type returns the opcode return sql type, and a bool telling is we are sure about this type or not
func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type {
func (code AggregateOpcode) SQLType(typ querypb.Type) querypb.Type {
switch code {
case AggregateUnassigned:
return sqltypes.Null
Expand All @@ -159,6 +161,28 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type {
}
}

func (code AggregateOpcode) Nullable() bool {
switch code {
case AggregateCount, AggregateCountStar:
return false
default:
return true
}
}

func (code AggregateOpcode) ResolveType(t evalengine.Type, env *collations.Environment) evalengine.Type {
sqltype := code.SQLType(t.Type())
collation := collations.CollationForType(sqltype, env.DefaultConnectionCharset())
nullable := code.Nullable()
size := t.Size()

scale := t.Scale()
if code == AggregateAvg {
scale += 4
}
return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale)
}

func (code AggregateOpcode) NeedsComparableValues() bool {
switch code {
case AggregateCountDistinct, AggregateSumDistinct, AggregateMin, AggregateMax:
Expand Down
133 changes: 132 additions & 1 deletion go/vt/vtgate/engine/opcode/constants_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,137 @@ import (
func TestCheckAllAggrOpCodes(t *testing.T) {
// This test is just checking that we never reach the panic when using Type() on valid opcodes
for i := AggregateOpcode(0); i < _NumOfOpCodes; i++ {
i.Type(sqltypes.Null)
i.SQLType(sqltypes.Null)
}
}
<<<<<<< HEAD
=======

func TestType(t *testing.T) {
tt := []struct {
opcode AggregateOpcode
typ querypb.Type
out querypb.Type
}{
{AggregateUnassigned, sqltypes.VarChar, sqltypes.Null},
{AggregateGroupConcat, sqltypes.VarChar, sqltypes.Text},
{AggregateGroupConcat, sqltypes.Blob, sqltypes.Blob},
{AggregateGroupConcat, sqltypes.Unknown, sqltypes.Unknown},
{AggregateMax, sqltypes.Int64, sqltypes.Int64},
{AggregateMax, sqltypes.Float64, sqltypes.Float64},
{AggregateSumDistinct, sqltypes.Unknown, sqltypes.Unknown},
{AggregateSumDistinct, sqltypes.Int64, sqltypes.Decimal},
{AggregateSumDistinct, sqltypes.Decimal, sqltypes.Decimal},
{AggregateCount, sqltypes.Int32, sqltypes.Int64},
{AggregateCountStar, sqltypes.Int64, sqltypes.Int64},
{AggregateGtid, sqltypes.VarChar, sqltypes.VarChar},
}

for _, tc := range tt {
t.Run(tc.opcode.String()+"_"+tc.typ.String(), func(t *testing.T) {
out := tc.opcode.SQLType(tc.typ)
assert.Equal(t, tc.out, out)
})
}
}

func TestType_Panic(t *testing.T) {
defer func() {
if r := recover(); r != nil {
errMsg, ok := r.(string)
assert.True(t, ok, "Expected a string panic message")
assert.Contains(t, errMsg, "ERROR", "Expected panic message containing 'ERROR'")
}
}()
AggregateOpcode(999).SQLType(sqltypes.VarChar)
}

func TestNeedsListArg(t *testing.T) {
tt := []struct {
opcode PulloutOpcode
out bool
}{
{PulloutValue, false},
{PulloutIn, true},
{PulloutNotIn, true},
{PulloutExists, false},
{PulloutNotExists, false},
}

for _, tc := range tt {
t.Run(tc.opcode.String(), func(t *testing.T) {
out := tc.opcode.NeedsListArg()
assert.Equal(t, tc.out, out)
})
}
}

func TestPulloutOpcode_MarshalJSON(t *testing.T) {
tt := []struct {
opcode PulloutOpcode
out string
}{
{PulloutValue, "\"PulloutValue\""},
{PulloutIn, "\"PulloutIn\""},
{PulloutNotIn, "\"PulloutNotIn\""},
{PulloutExists, "\"PulloutExists\""},
{PulloutNotExists, "\"PulloutNotExists\""},
}

for _, tc := range tt {
t.Run(tc.opcode.String(), func(t *testing.T) {
out, err := json.Marshal(tc.opcode)
require.NoError(t, err, "Unexpected error")
assert.Equal(t, tc.out, string(out))
})
}
}

func TestAggregateOpcode_MarshalJSON(t *testing.T) {
tt := []struct {
opcode AggregateOpcode
out string
}{
{AggregateCount, "\"count\""},
{AggregateSum, "\"sum\""},
{AggregateMin, "\"min\""},
{AggregateMax, "\"max\""},
{AggregateCountDistinct, "\"count_distinct\""},
{AggregateSumDistinct, "\"sum_distinct\""},
{AggregateGtid, "\"vgtid\""},
{AggregateCountStar, "\"count_star\""},
{AggregateGroupConcat, "\"group_concat\""},
{AggregateAnyValue, "\"any_value\""},
{AggregateAvg, "\"avg\""},
{999, "\"ERROR\""},
}

for _, tc := range tt {
t.Run(tc.opcode.String(), func(t *testing.T) {
out, err := json.Marshal(tc.opcode)
require.NoError(t, err, "Unexpected error")
assert.Equal(t, tc.out, string(out))
})
}
}

func TestNeedsComparableValues(t *testing.T) {
for i := AggregateOpcode(0); i < _NumOfOpCodes; i++ {
if i == AggregateCountDistinct || i == AggregateSumDistinct || i == AggregateMin || i == AggregateMax {
assert.True(t, i.NeedsComparableValues())
} else {
assert.False(t, i.NeedsComparableValues())
}
}
}

func TestIsDistinct(t *testing.T) {
for i := AggregateOpcode(0); i < _NumOfOpCodes; i++ {
if i == AggregateCountDistinct || i == AggregateSumDistinct {
assert.True(t, i.IsDistinct())
} else {
assert.False(t, i.IsDistinct())
}
}
}
>>>>>>> cd61d85130 (bugfix: wrong field type returned for SUM (#15192))
9 changes: 9 additions & 0 deletions go/vt/vtgate/engine/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,19 @@ func (p *Projection) evalFields(env *evalengine.ExpressionEnv, infields []*query
fl |= uint32(querypb.MySqlFlag_NOT_NULL_FLAG)
}
fields = append(fields, &querypb.Field{
<<<<<<< HEAD
Name: col,
Type: q,
Charset: uint32(cs),
Flags: fl,
=======
Name: col,
Type: typ.Type(),
Charset: uint32(typ.Collation()),
ColumnLength: uint32(typ.Size()),
Decimals: uint32(typ.Scale()),
Flags: fl,
>>>>>>> cd61d85130 (bugfix: wrong field type returned for SUM (#15192))
})
}
return fields, nil
Expand Down
Loading

0 comments on commit 12c3605

Please sign in to comment.