Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sort merge join #845

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 30 additions & 40 deletions pkg/dataset/sort_merge_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package dataset

import (
"bytes"
"io"
"sync"
)
Expand All @@ -31,6 +32,7 @@ import (
)

import (
"github.com/arana-db/arana/pkg/mysql"
"github.com/arana-db/arana/pkg/mysql/rows"
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/runtime/ast"
Expand Down Expand Up @@ -76,9 +78,9 @@ func NewSortMergeJoin(joinType ast.JoinType, joinColumn *JoinColumn, outer proto
return nil, errors.WithStack(err)
}

fields := make([]proto.Field, 0, len(outerFields)+len(innerFields))
fields = append(fields, outerFields...)
var fields []proto.Field
fields = append(fields, innerFields...)
csynineyang marked this conversation as resolved.
Show resolved Hide resolved
fields = append(fields, outerFields...)

if joinType == ast.RightJoin {
outer, inner = inner, outer
Expand Down Expand Up @@ -249,6 +251,10 @@ func (j *JoinColumn) Column() string {
return ""
}

func (j *JoinColumn) SetColumn(column string) {
j.column = column
}

func (s *SortMergeJoin) Close() error {
return nil
}
Expand All @@ -263,16 +269,12 @@ func (s *SortMergeJoin) Next() (proto.Row, error) {
outerRow, innerRow proto.Row
)

if s.LastRow() != nil {
outerRow = s.LastRow()
} else {
outerRow, err = s.getOuterRow()
if err != nil {
return nil, err
}
outerRow, err = s.getOuterRow()
if err != nil {
return nil, err
}

innerRow, err = s.getInnerRow(outerRow)
innerRow, err = s.getInnerRow()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -319,6 +321,7 @@ func (s *SortMergeJoin) innerJoin(outerRow proto.Row, innerRow proto.Row) (proto
if res, err := s.equalCompare(outerRow, innerRow, outerValue); err != nil {
return nil, err
} else {
s.SetLastInnerRow(innerRow)
return res, nil
}
}
Expand Down Expand Up @@ -469,12 +472,6 @@ func (s *SortMergeJoin) rightJoin(outerRow proto.Row, innerRow proto.Row) (proto
}

func (s *SortMergeJoin) getOuterRow() (proto.Row, error) {
nextOuterRow := s.NextOuterRow()
if nextOuterRow != nil {
s.ResetNextOuterRow()
return nextOuterRow, nil
}

leftRow, err := s.outer.Next()
if err != nil && errors.Is(err, io.EOF) {
return nil, nil
Expand All @@ -486,21 +483,7 @@ func (s *SortMergeJoin) getOuterRow() (proto.Row, error) {
return leftRow, nil
}

func (s *SortMergeJoin) getInnerRow(outerRow proto.Row) (proto.Row, error) {
if outerRow != nil {
outerValue, err := outerRow.(proto.KeyedRow).Get(s.joinColumn.Column())
if err != nil {
return nil, err
}

if s.DescartesFlag() {
innerRow := s.EqualValue(outerValue.String())
if innerRow != nil {
return innerRow, nil
}
}
}

func (s *SortMergeJoin) getInnerRow() (proto.Row, error) {
lastInnerRow := s.LastInnerRow()
if lastInnerRow != nil {
s.ResetLastInnerRow()
Expand All @@ -518,19 +501,22 @@ func (s *SortMergeJoin) getInnerRow(outerRow proto.Row) (proto.Row, error) {
return rightRow, nil
}

func (s *SortMergeJoin) resGenerate(leftRow proto.Row, rightRow proto.Row) proto.Row {
func (s *SortMergeJoin) resGenerate(rightRow proto.Row, leftRow proto.Row) proto.Row {
var (
leftValue []proto.Value
rightValue []proto.Value
res []proto.Value
realFields []proto.Field
)

if leftRow == nil && rightRow == nil {
return nil
}

leftFields, _ := s.outer.Fields()
rightFields, _ := s.inner.Fields()
leftFields, _ := s.inner.Fields()
realFields = append(realFields, leftFields[:(len(leftFields)-1)]...)
rightFields, _ := s.outer.Fields()
realFields = append(realFields, rightFields[:(len(rightFields)-1)]...)

leftValue = make([]proto.Value, len(leftFields))
rightValue = make([]proto.Value, len(rightFields))
Expand Down Expand Up @@ -560,12 +546,16 @@ func (s *SortMergeJoin) resGenerate(leftRow proto.Row, rightRow proto.Row) proto
}
}

res = append(res, leftValue...)
res = append(res, rightValue...)

fields, _ := s.Fields()
res = append(res, leftValue[:(len(leftValue)-1)]...)
res = append(res, rightValue[:(len(rightValue)-1)]...)

return rows.NewBinaryVirtualRow(fields, res)
var b bytes.Buffer
row := rows.NewTextVirtualRow(realFields, res)
_, err := row.WriteTo(&b)
if err != nil {
return nil
}
return mysql.NewTextRow(realFields, b.Bytes())
}

func (s *SortMergeJoin) equalCompare(outerRow proto.Row, innerRow proto.Row, outerValue proto.Value) (proto.Row, error) {
Expand Down Expand Up @@ -614,7 +604,7 @@ func (s *SortMergeJoin) greaterCompare(outerRow proto.Row) (proto.Row, proto.Row
innerRow = s.EqualValue(outerValue.String())
} else {
s.setDescartesFlag(NotDescartes)
innerRow, err = s.getInnerRow(outerRow)
innerRow, err = s.getInnerRow()
if err != nil {
return nil, nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/proto/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ func NewTableMetadata(name string, columnMetadataList []*ColumnMetadata, indexMe
}
}
for _, indexMetadata := range indexMetadataList {
indexName := strings.ToLower(indexMetadata.Name)
tma.Indexes[indexName] = indexMetadata
tma.Indexes[indexMetadata.ColumnName] = indexMetadata
}

return tma
Expand All @@ -66,7 +65,8 @@ type ColumnMetadata struct {
}

type IndexMetadata struct {
Name string
ColumnName string
Name string
}

var _defaultSchemaLoader SchemaLoader
Expand Down
69 changes: 51 additions & 18 deletions pkg/runtime/optimize/dml/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,12 +464,28 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt
if err != nil {
return nil, err
}
var tbLeft0 = tableLeft.Suffix()
if shardsLeft != nil {
_, tbLeft0 = shardsLeft.Smallest()
}
leftTblMeta, err := loadMetadataByTable(ctx, tbLeft0)
if err != nil {
return nil, err
}
Comment on lines +467 to +474
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add error context for better debugging.

When loading metadata for the left table, add context to the error message for easier debugging.

-    return nil, err
+    return nil, errors.Wrapf(err, "failed to load metadata for left table: %s", tbLeft0)
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
var tbLeft0 = tableLeft.Suffix()
if shardsLeft != nil {
_, tbLeft0 = shardsLeft.Smallest()
}
leftTblMeta, err := loadMetadataByTable(ctx, tbLeft0)
if err != nil {
return nil, err
}
var tbLeft0 = tableLeft.Suffix()
if shardsLeft != nil {
_, tbLeft0 = shardsLeft.Smallest()
}
leftTblMeta, err = loadMetadataByTable(ctx, tbLeft0)
if err != nil {
return nil, errors.Wrapf(err, "failed to load metadata for left table: %s", tbLeft0)
}


join := from.Joins[0]
dbRight, aliasRight, tableRight, shardsRight, err := compute(join.Target)
if err != nil {
return nil, err
}
var tbRight0 = tableRight.Suffix()
if shardsRight != nil {
_, tbRight0 = shardsRight.Smallest()
}
rightTblMeta, err := loadMetadataByTable(ctx, tbRight0)
if err != nil {
return nil, err
}
Comment on lines +481 to +488
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add error context for better debugging.

When loading metadata for the right table, add context to the error message for easier debugging.

-    return nil, err
+    return nil, errors.Wrapf(err, "failed to load metadata for right table: %s", tbRight0)
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
var tbRight0 = tableRight.Suffix()
if shardsRight != nil {
_, tbRight0 = shardsRight.Smallest()
}
rightTblMeta, err := loadMetadataByTable(ctx, tbRight0)
if err != nil {
return nil, err
}
var tbRight0 = tableRight.Suffix()
if shardsRight != nil {
_, tbRight0 = shardsRight.Smallest()
}
rightTblMeta, err := loadMetadataByTable(ctx, tbRight0)
if err != nil {
return nil, errors.Wrapf(err, "failed to load metadata for right table: %s", tbRight0)
}


// one db
if dbLeft == dbRight && shardsLeft == nil && shardsRight == nil {
Expand Down Expand Up @@ -517,6 +533,7 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt
return nil, errors.Errorf("not found buildKey or probeKey")
}

shouldSortMerge := shouldSortMergeJoin(leftTblMeta, rightTblMeta, leftKey, rightKey)
rewriteToSingle := func(tableSource ast.TableSourceItem, shards map[string][]string, onKey string) (proto.Plan, error) {
selectStmt := &ast.SelectStatement{
Select: stmt.Select,
Expand Down Expand Up @@ -594,30 +611,42 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt
return nil, err
}

setPlan := func(plan *dml.HashJoinPlan, buildPlan, probePlan proto.Plan, buildKey, probeKey string) {
plan.BuildKey = buildKey
plan.ProbeKey = probeKey
plan.BuildPlan = buildPlan
plan.ProbePlan = probePlan
}
var tmpPlan proto.Plan

if join.Typ == ast.InnerJoin {
setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey)
hashJoinPlan.IsFilterProbeRow = true
if shouldSortMerge {
tmpPlan = &dml.SortMergeJoin{
Stmt: stmt,
LeftQuery: leftPlan,
RightQuery: rightPlan,
JoinType: join.Typ,
LeftKey: leftKey,
RightKey: rightKey,
}
} else {
hashJoinPlan.IsFilterProbeRow = false
if join.Typ == ast.LeftJoin {
hashJoinPlan.IsReversedColumn = true
setPlan(hashJoinPlan, rightPlan, leftPlan, rightKey, leftKey)
} else if join.Typ == ast.RightJoin {
setPlan := func(plan *dml.HashJoinPlan, buildPlan, probePlan proto.Plan, buildKey, probeKey string) {
plan.BuildKey = buildKey
plan.ProbeKey = probeKey
plan.BuildPlan = buildPlan
plan.ProbePlan = probePlan
}

if join.Typ == ast.InnerJoin {
setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey)
hashJoinPlan.IsFilterProbeRow = true
} else {
return nil, errors.New("not support Join Type")
hashJoinPlan.IsFilterProbeRow = false
if join.Typ == ast.LeftJoin {
hashJoinPlan.IsReversedColumn = true
setPlan(hashJoinPlan, rightPlan, leftPlan, rightKey, leftKey)
} else if join.Typ == ast.RightJoin {
setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey)
} else {
return nil, errors.New("not support Join Type")
}
Comment on lines +626 to +645
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider refactoring for readability.

The nested if-else structure for setting the hash join plan can be refactored for better readability.

-    if join.Typ == ast.InnerJoin {
-        setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey)
-        hashJoinPlan.IsFilterProbeRow = true
-    } else {
-        hashJoinPlan.IsFilterProbeRow = false
-        if join.Typ == ast.LeftJoin {
-            hashJoinPlan.IsReversedColumn = true
-            setPlan(hashJoinPlan, rightPlan, leftPlan, rightKey, leftKey)
-        } else if join.Typ == ast.RightJoin {
-            setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey)
-        } else {
-            return nil, errors.New("not support Join Type")
-        }
-    }
+    switch join.Typ {
+    case ast.InnerJoin:
+        setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey)
+        hashJoinPlan.IsFilterProbeRow = true
+    case ast.LeftJoin:
+        hashJoinPlan.IsReversedColumn = true
+        setPlan(hashJoinPlan, rightPlan, leftPlan, rightKey, leftKey)
+        hashJoinPlan.IsFilterProbeRow = false
+    case ast.RightJoin:
+        setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey)
+        hashJoinPlan.IsFilterProbeRow = false
+    default:
+        return nil, errors.New("not support Join Type")
+    }
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
setPlan := func(plan *dml.HashJoinPlan, buildPlan, probePlan proto.Plan, buildKey, probeKey string) {
plan.BuildKey = buildKey
plan.ProbeKey = probeKey
plan.BuildPlan = buildPlan
plan.ProbePlan = probePlan
}
if join.Typ == ast.InnerJoin {
setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey)
hashJoinPlan.IsFilterProbeRow = true
} else {
return nil, errors.New("not support Join Type")
hashJoinPlan.IsFilterProbeRow = false
if join.Typ == ast.LeftJoin {
hashJoinPlan.IsReversedColumn = true
setPlan(hashJoinPlan, rightPlan, leftPlan, rightKey, leftKey)
} else if join.Typ == ast.RightJoin {
setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey)
} else {
return nil, errors.New("not support Join Type")
}
if join.Typ == ast.InnerJoin {
setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey)
hashJoinPlan.IsFilterProbeRow = true
} else {
hashJoinPlan.IsFilterProbeRow = false
if join.Typ == ast.LeftJoin {
hashJoinPlan.IsReversedColumn = true
setPlan(hashJoinPlan, rightPlan, leftPlan, rightKey, leftKey)
} else if join.Typ == ast.RightJoin {
setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey)
} else {
return nil, errors.New("not support Join Type")
}
}
```
Updated code:
```suggestion
switch join.Typ {
case ast.InnerJoin:
setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey)
hashJoinPlan.IsFilterProbeRow = true
case ast.LeftJoin:
hashJoinPlan.IsReversedColumn = true
setPlan(hashJoinPlan, rightPlan, leftPlan, rightKey, leftKey)
hashJoinPlan.IsFilterProbeRow = false
case ast.RightJoin:
setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey)
hashJoinPlan.IsFilterProbeRow = false
default:
return nil, errors.New("not support Join Type")
}

}
}

var tmpPlan proto.Plan
tmpPlan = hashJoinPlan
tmpPlan = hashJoinPlan
}

var (
analysis selectResult
Expand Down Expand Up @@ -700,6 +729,10 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt
return tmpPlan, nil
}

func shouldSortMergeJoin(leftTblMeta, rightTblMeta *proto.TableMetadata, leftKey, rightKey string) bool {
return leftTblMeta.Indexes[leftKey] != nil && rightTblMeta.Indexes[rightKey] != nil
}

func getSelectFlag(ru *rule.Rule, stmt *ast.SelectStatement) (flag uint32) {
switch len(stmt.From) {
case 1:
Expand Down
63 changes: 35 additions & 28 deletions pkg/runtime/plan/dml/rename.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package dml

import (
"context"
"fmt"
)

import (
Expand Down Expand Up @@ -52,38 +51,46 @@ func (rp RenamePlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result
}

convFields := func(fields []proto.Field) []proto.Field {
if len(rp.RenameList) != len(fields) {
panic(fmt.Sprintf("the length of field doesn't match: expect=%d, actual=%d!", len(rp.RenameList), len(fields)))
}
//if len(rp.RenameList) != len(fields) {
// panic(fmt.Sprintf("the length of field doesn't match: expect=%d, actual=%d!", len(rp.RenameList), len(fields)))
//}
Comment on lines +54 to +56
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove commented-out code

Consider removing the commented-out code to clean up the codebase.

-	//if len(rp.RenameList) != len(fields) {
-	//	panic(fmt.Sprintf("the length of field doesn't match: expect=%d, actual=%d!", len(rp.RenameList), len(fields)))
-	//}
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
//if len(rp.RenameList) != len(fields) {
// panic(fmt.Sprintf("the length of field doesn't match: expect=%d, actual=%d!", len(rp.RenameList), len(fields)))
//}


var renames map[int]struct{}
newFields := make([]proto.Field, 0, len(fields))
for i := 0; i < len(rp.RenameList); i++ {
rename := rp.RenameList[i]
name := fields[i].Name()
if rename == name {
continue
}
if renames == nil {
renames = make(map[int]struct{})
}
renames[i] = struct{}{}
f := *(fields[i].(*mysql.Field))
f.SetName(rp.RenameList[i])
f.SetOrgName(rp.RenameList[i])
newFields = append(newFields, &f)
}

if len(renames) < 1 {
return fields
}
// var renames map[int]struct{}
// for i := 0; i < len(rp.RenameList); i++ {
// rename := rp.RenameList[i]
// name := fields[i].Name()
// if rename == name {
// continue
// }
// if renames == nil {
// renames = make(map[int]struct{})
// }
// renames[i] = struct{}{}
// }

newFields := make([]proto.Field, 0, len(fields))
for i := 0; i < len(fields); i++ {
if _, ok := renames[i]; ok {
f := *(fields[i].(*mysql.Field))
f.SetName(rp.RenameList[i])
f.SetOrgName(rp.RenameList[i])
newFields = append(newFields, &f)
} else {
newFields = append(newFields, fields[i])
}
}
// if len(renames) < 1 {
// return fields
// }

// newFields := make([]proto.Field, 0, len(fields))
// for i := 0; i < len(fields); i++ {
// if _, ok := renames[i]; ok {
// f := *(fields[i].(*mysql.Field))
// f.SetName(rp.RenameList[i])
// f.SetOrgName(rp.RenameList[i])
// newFields = append(newFields, &f)
// } else {
// newFields = append(newFields, fields[i])
// }
// }
Comment on lines +66 to +93
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove commented-out code

Consider removing the commented-out code to clean up the codebase.

-	// var renames map[int]struct{}
-	// for i := 0; i < len(rp.RenameList); i++ {
-	// 	rename := rp.RenameList[i]
-	// 	name := fields[i].Name()
-	// 	if rename == name {
-	// 		continue
-	// 	}
-	// 	if renames == nil {
-	// 		renames = make(map[int]struct{})
-	// 	}
-	// 	renames[i] = struct{}{}
-	// }
-	// if len(renames) < 1 {
-	// 	return fields
-	// }
-	// newFields := make([]proto.Field, 0, len(fields))
-	// for i := 0; i < len(fields); i++ {
-	// 	if _, ok := renames[i]; ok {
-	// 		f := *(fields[i].(*mysql.Field))
-	// 		f.SetName(rp.RenameList[i])
-	// 		f.SetOrgName(rp.RenameList[i])
-	// 		newFields = append(newFields, &f)
-	// 	} else {
-	// 		newFields = append(newFields, fields[i])
-	// 	}
-	// }
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// var renames map[int]struct{}
// for i := 0; i < len(rp.RenameList); i++ {
// rename := rp.RenameList[i]
// name := fields[i].Name()
// if rename == name {
// continue
// }
// if renames == nil {
// renames = make(map[int]struct{})
// }
// renames[i] = struct{}{}
// }
newFields := make([]proto.Field, 0, len(fields))
for i := 0; i < len(fields); i++ {
if _, ok := renames[i]; ok {
f := *(fields[i].(*mysql.Field))
f.SetName(rp.RenameList[i])
f.SetOrgName(rp.RenameList[i])
newFields = append(newFields, &f)
} else {
newFields = append(newFields, fields[i])
}
}
// if len(renames) < 1 {
// return fields
// }
// newFields := make([]proto.Field, 0, len(fields))
// for i := 0; i < len(fields); i++ {
// if _, ok := renames[i]; ok {
// f := *(fields[i].(*mysql.Field))
// f.SetName(rp.RenameList[i])
// f.SetOrgName(rp.RenameList[i])
// newFields = append(newFields, &f)
// } else {
// newFields = append(newFields, fields[i])
// }
// }

return newFields
}

Expand Down
Loading
Loading