Skip to content

Commit

Permalink
[feat](Nereids) support nereids hint position detection (#39113) (#39416
Browse files Browse the repository at this point in the history
)

cherry-pick from master #39113
When a hint is used in the wrong position, or an unsupported hint is used,
route it to channel(2) so that it is filtered out
  • Loading branch information
LiBinfeng-01 authored Aug 16, 2024
1 parent 0adf48a commit d6a0469
Show file tree
Hide file tree
Showing 7 changed files with 248 additions and 47 deletions.
18 changes: 18 additions & 0 deletions fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,19 @@ lexer grammar DorisLexer;
/**
 * Records that a bracketed comment was left unclosed.
 * Invoked by the BRACKETED_COMMENT rule when it reaches EOF instead of a closing star-slash.
 */
public void markUnclosedComment() {
has_unclosed_bracketed_comment = true;
}

// Externally controlled switch (see setChannel2): while true, the HINT_WITH_CHANNEL rule is
// enabled and whole hint comments are emitted as single tokens on channel 2.
private boolean channel2;

/**
 * Turns channel-2 hint handling on or off for this lexer instance.
 *
 * @param value true to let the HINT_WITH_CHANNEL rule match hint comments, false to disable it
 */
public void setChannel2(boolean value) {
this.channel2 = value;
}

/**
 * Reports whether channel-2 hint handling is enabled.
 * Used as the semantic predicate that guards the HINT_WITH_CHANNEL rule.
 *
 * @return the value last passed to setChannel2 (false if it was never called)
 */
private boolean isChannel2() {
return this.channel2;
}
}

SEMICOLON: ';';
Expand Down Expand Up @@ -654,6 +667,11 @@ BRACKETED_COMMENT
: '/*' {!isHint()}? ( BRACKETED_COMMENT | . )*? ('*/' | {markUnclosedComment();} EOF) -> channel(HIDDEN)
;

// Only enabled while isChannel2() returns true. Matches a complete hint comment, from
// HINT_START up to the nearest HINT_END (non-greedy), as one token and emits that token
// on channel 2 instead of the default channel.
HINT_WITH_CHANNEL
: {isChannel2()}? HINT_START .*? HINT_END -> channel(2)
;


// Matches FROM, one or more WS, then DUAL as a single token and sends it to the HIDDEN channel.
FROM_DUAL
: 'FROM' WS+ 'DUAL' -> channel(HIDDEN);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,12 @@
@SuppressWarnings({"OptionalUsedAsFieldOrParameterType", "OptionalGetWithoutIsPresent"})
public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {

private final Map<Integer, ParserRuleContext> selectHintMap;

/**
 * Creates a plan builder that resolves select hints from a pre-parsed hint map.
 *
 * @param selectHintMap parsed hint contexts keyed by an integer position that is compared
 *        against the token indexes of each select clause; when it is null or empty, query
 *        specifications are built without any select hint
 */
public LogicalPlanBuilder(Map<Integer, ParserRuleContext> selectHintMap) {
this.selectHintMap = selectHintMap;
}

@SuppressWarnings("unchecked")
protected <T> T typedVisit(ParseTree ctx) {
return (T) ctx.accept(this);
Expand Down Expand Up @@ -604,7 +610,16 @@ public LogicalPlan visitRegularQuerySpecification(RegularQuerySpecificationConte
Optional.ofNullable(ctx.aggClause()),
Optional.ofNullable(ctx.havingClause()));
selectPlan = withQueryOrganization(selectPlan, ctx.queryOrganization());
return withSelectHint(selectPlan, selectCtx.selectHint());
if ((selectHintMap == null) || selectHintMap.isEmpty()) {
return selectPlan;
}
List<ParserRuleContext> selectHintContexts = Lists.newArrayList();
for (Integer key : selectHintMap.keySet()) {
if (key > selectCtx.getStart().getStopIndex() && key < selectCtx.getStop().getStartIndex()) {
selectHintContexts.add(selectHintMap.get(key));
}
}
return withSelectHint(selectPlan, selectHintContexts);
});
}

Expand Down Expand Up @@ -1785,47 +1800,50 @@ private LogicalPlan withJoinRelations(LogicalPlan input, RelationContext ctx) {
return last;
}

private LogicalPlan withSelectHint(LogicalPlan logicalPlan, SelectHintContext hintContext) {
if (hintContext == null) {
private LogicalPlan withSelectHint(LogicalPlan logicalPlan, List<ParserRuleContext> hintContexts) {
if (hintContexts.isEmpty()) {
return logicalPlan;
}
Map<String, SelectHint> hints = Maps.newLinkedHashMap();
for (HintStatementContext hintStatement : hintContext.hintStatements) {
String hintName = hintStatement.hintName.getText().toLowerCase(Locale.ROOT);
switch (hintName) {
case "set_var":
Map<String, Optional<String>> parameters = Maps.newLinkedHashMap();
for (HintAssignmentContext kv : hintStatement.parameters) {
if (kv.key != null) {
String parameterName = visitIdentifierOrText(kv.key);
Optional<String> value = Optional.empty();
if (kv.constantValue != null) {
Literal literal = (Literal) visit(kv.constantValue);
value = Optional.ofNullable(literal.toLegacyLiteral().getStringValue());
} else if (kv.identifierValue != null) {
// maybe we should throw exception when the identifierValue is quoted identifier
value = Optional.ofNullable(kv.identifierValue.getText());
for (ParserRuleContext hintContext : hintContexts) {
SelectHintContext selectHintContext = (SelectHintContext) hintContext;
for (HintStatementContext hintStatement : selectHintContext.hintStatements) {
String hintName = hintStatement.hintName.getText().toLowerCase(Locale.ROOT);
switch (hintName) {
case "set_var":
Map<String, Optional<String>> parameters = Maps.newLinkedHashMap();
for (HintAssignmentContext kv : hintStatement.parameters) {
if (kv.key != null) {
String parameterName = visitIdentifierOrText(kv.key);
Optional<String> value = Optional.empty();
if (kv.constantValue != null) {
Literal literal = (Literal) visit(kv.constantValue);
value = Optional.ofNullable(literal.toLegacyLiteral().getStringValue());
} else if (kv.identifierValue != null) {
// maybe we should throw exception when the identifierValue is quoted identifier
value = Optional.ofNullable(kv.identifierValue.getText());
}
parameters.put(parameterName, value);
}
parameters.put(parameterName, value);
}
}
hints.put(hintName, new SelectHintSetVar(hintName, parameters));
break;
case "leading":
List<String> leadingParameters = new ArrayList<String>();
for (HintAssignmentContext kv : hintStatement.parameters) {
if (kv.key != null) {
String parameterName = visitIdentifierOrText(kv.key);
leadingParameters.add(parameterName);
hints.put(hintName, new SelectHintSetVar(hintName, parameters));
break;
case "leading":
List<String> leadingParameters = new ArrayList<String>();
for (HintAssignmentContext kv : hintStatement.parameters) {
if (kv.key != null) {
String parameterName = visitIdentifierOrText(kv.key);
leadingParameters.add(parameterName);
}
}
}
hints.put(hintName, new SelectHintLeading(hintName, leadingParameters));
break;
case "ordered":
hints.put(hintName, new SelectHintOrdered(hintName));
break;
default:
break;
hints.put(hintName, new SelectHintLeading(hintName, leadingParameters));
break;
case "ordered":
hints.put(hintName, new SelectHintOrdered(hintName));
break;
default:
break;
}
}
}
return new LogicalSelectHint<>(hints, logicalPlan);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.atn.PredictionMode;
import org.antlr.v4.runtime.misc.ParseCancellationException;

import java.util.List;
import java.util.Map;
import java.util.function.Function;

/**
Expand Down Expand Up @@ -80,12 +83,41 @@ public List<String> parseDataType(String dataType) {

private <T> T parse(String sql, Function<DorisParser, ParserRuleContext> parseFunction) {
ParserRuleContext tree = toAst(sql, parseFunction);
LogicalPlanBuilder logicalPlanBuilder = new LogicalPlanBuilder();
LogicalPlanBuilder logicalPlanBuilder = new LogicalPlanBuilder(getHintMap(sql, DorisParser::selectHint));
return (T) logicalPlanBuilder.visit(tree);
}

/**
 * Collects every hint comment found in {@code sql} and parses each one on its own.
 *
 * <p>The statement is lexed once with channel-2 mode enabled, so every hint comment arrives as a
 * single {@code HINT_WITH_CHANNEL} token. The text of each such token is then lexed again with
 * channel-2 mode disabled and handed to {@code parseFunction}.
 *
 * @param sql the full SQL text to scan
 * @param parseFunction parser entry rule applied to the text of each hint
 * @return parsed hint contexts keyed by the start index of the hint token inside {@code sql};
 *         the map iterates in the order in which the hints appear in {@code sql}
 */
public static Map<Integer, ParserRuleContext> getHintMap(String sql,
        Function<DorisParser, ParserRuleContext> parseFunction) {
    // first round: lex the whole statement so that each hint comes out as one token
    DorisLexer hintLexer = new DorisLexer(new CaseInsensitiveStream(CharStreams.fromString(sql)));
    hintLexer.setChannel2(true);

    // Insertion-ordered on purpose: tokens are produced left to right, so callers that iterate
    // the map see the hints in source order instead of an unspecified hash order.
    Map<Integer, ParserRuleContext> selectHintMap = Maps.newLinkedHashMap();

    for (Token hintToken = hintLexer.nextToken();
            hintToken != null && hintToken.getType() != DorisLexer.EOF;
            hintToken = hintLexer.nextToken()) {
        if (hintToken.getType() != DorisLexer.HINT_WITH_CHANNEL) {
            continue;
        }
        // second round: lex the hint text alone with channel-2 mode off, so the
        // HINT_WITH_CHANNEL rule is disabled and the hint's own tokens reach the parser
        String hintSql = sql.substring(hintToken.getStartIndex(), hintToken.getStopIndex() + 1);
        DorisLexer newHintLexer = new DorisLexer(new CaseInsensitiveStream(CharStreams.fromString(hintSql)));
        newHintLexer.setChannel2(false);
        DorisParser hintParser = new DorisParser(new CommonTokenStream(newHintLexer));
        selectHintMap.put(hintToken.getStartIndex(), parseFunction.apply(hintParser));
    }
    return selectHintMap;
}

/** toAst */
private ParserRuleContext toAst(String sql, Function<DorisParser, ParserRuleContext> parseFunction) {
DorisLexer lexer = new DorisLexer(new CaseInsensitiveStream(CharStreams.fromString(sql)));
lexer.setChannel2(true);
CommonTokenStream tokenStream = new CommonTokenStream(lexer);
DorisParser parser = new DorisParser(tokenStream);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,16 +341,10 @@ public void testJoinHint() {
parsePlan("select * from t1 join [broadcast] t2 on t1.key1=t2.key1")
.matches(logicalJoin().when(j -> j.getHint() == JoinHint.BROADCAST_RIGHT));

parsePlan("select * from t1 join /*+ broadcast */ t2 on t1.key1=t2.key1")
.matches(logicalJoin().when(j -> j.getHint() == JoinHint.BROADCAST_RIGHT));

// invalid hint position
parsePlan("select * from [shuffle] t1 join t2 on t1.key1=t2.key1")
.assertThrowsExactly(ParseException.class);

parsePlan("select * from /*+ shuffle */ t1 join t2 on t1.key1=t2.key1")
.assertThrowsExactly(ParseException.class);

// invalid hint content
parsePlan("select * from t1 join [bucket] t2 on t1.key1=t2.key1")
.assertThrowsExactly(ParseException.class)
Expand All @@ -361,8 +355,6 @@ public void testJoinHint() {
+ "----------------------^^^");

// invalid multiple hints
parsePlan("select * from t1 join /*+ shuffle , broadcast */ t2 on t1.key1=t2.key1")
.assertThrowsExactly(ParseException.class);

parsePlan("select * from t1 join [shuffle,broadcast] t2 on t1.key1=t2.key1")
.assertThrowsExactly(ParseException.class);
Expand Down
81 changes: 81 additions & 0 deletions regression-test/data/nereids_p0/hint/test_hint.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select1_1 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute
------hashAgg[LOCAL]
--------hashJoin[INNER_JOIN](t1.c1 = t2.c2)
----------PhysicalOlapScan[t2]
----------PhysicalDistribute
------------PhysicalOlapScan[t1]

Hint log:
Used: leading(t2 broadcast t1 )
UnUsed:
SyntaxError:

-- !select1_2 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute
------hashAgg[LOCAL]
--------PhysicalStorageLayerAggregate

-- !select1_3 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute
------hashAgg[LOCAL]
--------PhysicalStorageLayerAggregate

-- !select1_4 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute
------hashAgg[LOCAL]
--------PhysicalStorageLayerAggregate

-- !select1_5 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute
------hashAgg[LOCAL]
--------hashJoin[INNER_JOIN](t1.c1 = t2.c2)
----------PhysicalOlapScan[t2]
----------PhysicalDistribute
------------PhysicalOlapScan[t1]

Hint log:
Used: leading(t2 broadcast t1 )
UnUsed:
SyntaxError:

-- !select1_6 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute
------hashAgg[LOCAL]
--------hashJoin[INNER_JOIN](t1.c1 = t2.c2)
----------PhysicalOlapScan[t2]
----------PhysicalDistribute
------------PhysicalOlapScan[t1]

Hint log:
Used: leading(t2 broadcast t1 )
UnUsed:
SyntaxError:

-- !select1_7 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute
------hashAgg[LOCAL]
--------PhysicalStorageLayerAggregate

-- !select1_8 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute
------hashAgg[LOCAL]
--------PhysicalStorageLayerAggregate

4 changes: 2 additions & 2 deletions regression-test/data/nereids_p0/hint/test_leading.out
Original file line number Diff line number Diff line change
Expand Up @@ -2538,8 +2538,8 @@ PhysicalResultSink
------------PhysicalOlapScan[t3]

Hint log:
Used: leading(t1 broadcast t2 t3 )
UnUsed:
Used: leading(t1 broadcast t2 broadcast t3 )
UnUsed:
SyntaxError:

-- !select95_4 --
Expand Down
60 changes: 60 additions & 0 deletions regression-test/suites/nereids_p0/hint/test_hint.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

// Regression suite for hint position handling: runs explain shape plan on queries whose
// hint comments appear after SELECT, before SELECT, or both.
suite("test_hint") {
// recreate a dedicated database for this suite and switch to it
sql 'DROP DATABASE IF EXISTS test_hint'
sql 'CREATE DATABASE IF NOT EXISTS test_hint'
sql 'use test_hint'

// session settings: use the Nereids planner with no fallback to the original planner;
// the remaining variables presumably keep the explain shape output stable
sql 'set exec_mem_limit=21G'
sql 'set be_number_for_test=1'
sql 'set parallel_pipeline_task_num=1'
sql "set disable_nereids_rules=PRUNE_EMPTY_PARTITION"
sql 'set enable_nereids_planner=true'
sql "set ignore_shape_nodes='PhysicalProject'"
sql 'set enable_fallback_to_original_planner=false'
sql 'set runtime_filter_mode=OFF'

// (re)create the two tables referenced by the queries below
sql """drop table if exists t1;"""
sql """drop table if exists t2;"""

sql """create table t1 (c1 int, c11 int) distributed by hash(c1) buckets 3 properties('replication_num' = '1');"""
sql """create table t2 (c2 int, c22 int) distributed by hash(c2) buckets 3 properties('replication_num' = '1');"""

// test hint positions; the join is removed from some queries so that the plan shape
// stays stable when no hint is applied
qt_select1_1 """explain shape plan select /*+ leading(t2 broadcast t1) */ count(*) from t1 join t2 on c1 = c2;"""

qt_select1_2 """explain shape plan /*+ leading(t2 broadcast t1) */ select count(*) from t1;"""

qt_select1_3 """explain shape plan select /*+DBP: ROUTE={GROUP_ID(zjaq)}*/ count(*) from t1;"""

qt_select1_4 """explain shape plan/*+DBP: ROUTE={GROUP_ID(zjaq)}*/ select count(*) from t1;"""

qt_select1_5 """explain shape plan /*+ leading(t2 broadcast t1) */ select /*+ leading(t2 broadcast t1) */ count(*) from t1 join t2 on c1 = c2;"""

qt_select1_6 """explain shape plan/*+DBP: ROUTE={GROUP_ID(zjaq)}*/ select /*+ leading(t2 broadcast t1) */ count(*) from t1 join t2 on c1 = c2;"""

qt_select1_7 """explain shape plan /*+ leading(t2 broadcast t1) */ select /*+DBP: ROUTE={GROUP_ID(zjaq)}*/ count(*) from t1;"""

qt_select1_8 """explain shape plan /*+DBP: ROUTE={GROUP_ID(zjaq)}*/ select /*+DBP: ROUTE={GROUP_ID(zjaq)}*/ count(*) from t1;"""

}

0 comments on commit d6a0469

Please sign in to comment.