Skip to content

Commit

Permalink
[feat](Nereids): Refactor Eliminate_Group_By_Key by functional depend…
Browse files Browse the repository at this point in the history
…encies (#34948)
  • Loading branch information
keanji-x authored May 22, 2024
1 parent 3c8a6ee commit 01fa2e6
Show file tree
Hide file tree
Showing 12 changed files with 560 additions and 388 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
import org.apache.doris.nereids.trees.expressions.Slot;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

Expand Down Expand Up @@ -71,6 +74,37 @@ public int size() {
return items.size();
}

public boolean isEmpty() {
return items.isEmpty();
}

/**
* Eliminate all deps in slots
*/
public Set<Set<Slot>> eliminateDeps(Set<Set<Slot>> slots) {
Set<Set<Slot>> minSlotSet = slots;
List<Set<Set<Slot>>> reduceSlotSets = new ArrayList<>();
reduceSlotSets.add(slots);
while (!reduceSlotSets.isEmpty()) {
List<Set<Set<Slot>>> newReduceSlotSets = new ArrayList<>();
for (Set<Set<Slot>> slotSet : reduceSlotSets) {
for (FuncDepsItem funcDepsItem : items) {
if (slotSet.contains(funcDepsItem.dependencies)
&& slotSet.contains(funcDepsItem.determinants)) {
Set<Set<Slot>> newSet = Sets.newHashSet(slotSet);
newSet.remove(funcDepsItem.dependencies);
if (minSlotSet.size() > newSet.size()) {
minSlotSet = newSet;
}
newReduceSlotSets.add(newSet);
}
}
}
reduceSlotSets = newReduceSlotSets;
}
return minSlotSet;
}

public boolean isFuncDeps(Set<Slot> dominate, Set<Slot> dependency) {
return items.contains(new FuncDepsItem(dominate, dependency));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.properties;

import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.util.ImmutableEqualSet;

import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -196,6 +197,9 @@ public void addFuncDepsDG(FunctionalDependencies fd) {
}

public void addDeps(Set<Slot> dominate, Set<Slot> dependency) {
if (dominate.containsAll(dependency)) {
return;
}
fdDgBuilder.addDeps(dominate, dependency);
}

Expand Down Expand Up @@ -265,19 +269,32 @@ public List<Set<Slot>> calEqualSetList() {
/**
* get all unique slots
*/
public List<Set<Slot>> getAllUnique() {
List<Set<Slot>> res = new ArrayList<>(uniqueSet.slotSets);
for (Slot s : uniqueSet.slots) {
res.add(ImmutableSet.of(s));
public List<Set<Slot>> getAllUniqueAndNotNull() {
List<Set<Slot>> res = new ArrayList<>();
for (Slot slot : uniqueSet.slots) {
if (!slot.nullable()) {
res.add(ImmutableSet.of(slot));
}
}
for (Set<Slot> slotSet : uniqueSet.slotSets) {
if (slotSet.stream().noneMatch(ExpressionTrait::nullable)) {
res.add(slotSet);
}
}
return res;
}

/**
* get all uniform slots
*/
public Set<Slot> getAllUniform() {
return uniformSet.slots;
public List<Set<Slot>> getAllUniformAndNotNull() {
List<Set<Slot>> res = new ArrayList<>();
for (Slot s : uniformSet.slots) {
if (!s.nullable()) {
res.add(ImmutableSet.of(s));
}
}
return res;
}

public void addEqualPair(Slot l, Slot r) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ public enum RuleType {
ELIMINATE_JOIN_BY_UK(RuleTypeClass.REWRITE),
ELIMINATE_JOIN_BY_FK(RuleTypeClass.REWRITE),
ELIMINATE_GROUP_BY_KEY(RuleTypeClass.REWRITE),
ELIMINATE_FILTER_GROUP_BY_KEY(RuleTypeClass.REWRITE),
ELIMINATE_DEDUP_JOIN_CONDITION(RuleTypeClass.REWRITE),
ELIMINATE_NULL_AWARE_LEFT_ANTI_JOIN(RuleTypeClass.REWRITE),
ELIMINATE_ASSERT_NUM_ROWS(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,193 +17,101 @@

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.properties.FdItem;
import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.properties.FuncDeps;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;


/**
* Eliminate group by key based on fd item information.
* such as:
* for a -> b, we can get:
* group by a, b, c => group by a, c
*/
public class EliminateGroupByKey extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalAggregate(logicalProject()).then(agg -> {
LogicalPlan childPlan = agg.child();
List<FdItem> uniqueFdItems = new ArrayList<>();
List<FdItem> nonUniqueFdItems = new ArrayList<>();
if (agg.getGroupByExpressions().isEmpty()
|| !agg.getGroupByExpressions().stream().allMatch(e -> e instanceof SlotReference)) {
return null;
}
ImmutableSet<FdItem> fdItems = childPlan.getLogicalProperties().getFunctionalDependencies().getFdItems();
if (fdItems.isEmpty()) {
return null;
}
List<SlotReference> candiExprs = agg.getGroupByExpressions().stream()
.map(SlotReference.class::cast).collect(Collectors.toList());
@DependsRules({EliminateGroupBy.class, ColumnPruning.class})
public class EliminateGroupByKey implements RewriteRuleFactory {

fdItems.stream().filter(e -> !e.isCandidate()).forEach(e -> {
if (e.isUnique()) {
uniqueFdItems.add(e);
} else {
nonUniqueFdItems.add(e);
}
}
);
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.ELIMINATE_GROUP_BY_KEY.build(
logicalProject(logicalAggregate().when(agg -> !agg.getSourceRepeat().isPresent()))
.then(proj -> {
LogicalAggregate<? extends Plan> agg = proj.child();
LogicalAggregate<Plan> newAgg = eliminateGroupByKey(agg, proj.getInputSlots());
if (newAgg == null) {
return null;
}
return proj.withChildren(newAgg);
})),
RuleType.ELIMINATE_FILTER_GROUP_BY_KEY.build(
logicalProject(logicalFilter(logicalAggregate()
.when(agg -> !agg.getSourceRepeat().isPresent())))
.then(proj -> {
LogicalAggregate<? extends Plan> agg = proj.child().child();
Set<Slot> requireSlots = new HashSet<>(proj.getInputSlots());
requireSlots.addAll(proj.child(0).getInputSlots());
LogicalAggregate<Plan> newAgg = eliminateGroupByKey(agg, proj.getOutputSet());
if (newAgg == null) {
return null;
}
return proj.withChildren(proj.child().withChildren(newAgg));
})
)
);
}

int minParentExprCnt = -1;
ImmutableSet<SlotReference> minParentExprs = ImmutableSet.of();
// if unique fd items exists, try to find the one which has the
// smallest parent exprs
for (int i = 0; i < uniqueFdItems.size(); i++) {
FdItem fdItem = uniqueFdItems.get(i);
ImmutableSet<SlotReference> parentExprs = fdItem.getParentExprs();
if (minParentExprCnt == -1 || parentExprs.size() < minParentExprCnt) {
boolean isContain = isExprsContainFdParent(candiExprs, fdItem);
if (isContain) {
minParentExprCnt = parentExprs.size();
minParentExprs = ImmutableSet.copyOf(parentExprs);
}
}
}
LogicalAggregate<Plan> eliminateGroupByKey(LogicalAggregate<? extends Plan> agg, Set<Slot> requireOutput) {
Map<Expression, Set<Slot>> groupBySlots = new HashMap<>();
Set<Slot> validSlots = new HashSet<>();
for (Expression expression : agg.getGroupByExpressions()) {
groupBySlots.put(expression, expression.getInputSlots());
validSlots.addAll(expression.getInputSlots());
}

Set<Integer> rootExprsSet = new HashSet<>();
List<SlotReference> rootExprs = new ArrayList<>();
Set<Integer> eliminateSet = new HashSet<>();
if (minParentExprs.size() > 0) {
// if any unique fd item found, find the expr which matching parentExprs
// from candiExprs directly
for (int i = 0; i < minParentExprs.size(); i++) {
int index = findEqualExpr(candiExprs, minParentExprs.asList().get(i));
if (index != -1) {
rootExprsSet.add(new Integer(index));
} else {
return null;
}
}
} else {
// no unique fd item found, try to find the smallest root exprs set
// from non-unique fd items.
for (int i = 0; i < nonUniqueFdItems.size() && eliminateSet.size() < candiExprs.size(); i++) {
FdItem fdItem = nonUniqueFdItems.get(i);
ImmutableSet<SlotReference> parentExprs = fdItem.getParentExprs();
boolean isContains = isExprsContainFdParent(candiExprs, fdItem);
if (isContains) {
List<SlotReference> leftDomain = new ArrayList<>();
List<SlotReference> rightDomain = new ArrayList<>();
// generate new root exprs
for (int j = 0; j < rootExprs.size(); j++) {
leftDomain.add(rootExprs.get(j));
boolean isInChild = fdItem.checkExprInChild(rootExprs.get(j), childPlan);
if (!isInChild) {
rightDomain.add(rootExprs.get(j));
}
}
for (int j = 0; j < parentExprs.size(); j++) {
int index = findEqualExpr(candiExprs, parentExprs.asList().get(j));
if (index != -1) {
rightDomain.add(candiExprs.get(index));
if (!eliminateSet.contains(index)) {
leftDomain.add(candiExprs.get(index));
}
}
}
// check fd can eliminate new candi expr
for (int j = 0; j < candiExprs.size(); j++) {
if (!eliminateSet.contains(j)) {
boolean isInChild = fdItem.checkExprInChild(candiExprs.get(j), childPlan);
if (isInChild) {
eliminateSet.add(j);
}
}
}
// if fd eliminate new candi exprs or new root exprs is smaller than the older,
// than use new root expr to replace old ones
List<SlotReference> newRootExprs = leftDomain.size() <= rightDomain.size()
? leftDomain : rightDomain;
rootExprs.clear();
rootExprs.addAll(newRootExprs);
}
}
}
// find the root expr, add into root exprs set, indicate the index in
// candiExprs list
for (int i = 0; i < rootExprs.size(); i++) {
int index = findEqualExpr(candiExprs, rootExprs.get(i));
if (index != -1) {
rootExprsSet.add(new Integer(index));
} else {
return null;
}
}
// other can't be determined expr, add into root exprs directly
if (eliminateSet.size() < candiExprs.size()) {
for (int i = 0; i < candiExprs.size(); i++) {
if (!eliminateSet.contains(i)) {
rootExprsSet.add(i);
}
}
}
rootExprs.clear();
for (int i = 0; i < candiExprs.size(); i++) {
if (rootExprsSet.contains(i)) {
rootExprs.add(candiExprs.get(i));
}
}
FuncDeps funcDeps = agg.child().getLogicalProperties()
.getFunctionalDependencies().getAllValidFuncDeps(validSlots);
if (funcDeps.isEmpty()) {
return null;
}

// use the new rootExprs as new group by keys
List<Expression> resultExprs = new ArrayList<>();
for (int i = 0; i < rootExprs.size(); i++) {
resultExprs.add(rootExprs.get(i));
Set<Set<Slot>> minGroupBySlots = funcDeps.eliminateDeps(new HashSet<>(groupBySlots.values()));
Set<Expression> removeExpression = new HashSet<>();
for (Entry<Expression, Set<Slot>> entry : groupBySlots.entrySet()) {
if (!minGroupBySlots.contains(entry.getValue())
&& !requireOutput.containsAll(entry.getValue())) {
removeExpression.add(entry.getKey());
}
}

// eliminate outputs keys
// TODO: remove outputExprList computing
List<NamedExpression> outputExprList = new ArrayList<>();
for (int i = 0; i < agg.getOutputExpressions().size(); i++) {
if (rootExprsSet.contains(i)) {
outputExprList.add(agg.getOutputExpressions().get(i));
}
}
// find the remained outputExprs list
List<NamedExpression> remainedOutputExprList = new ArrayList<>();
for (int i = 0; i < agg.getOutputExpressions().size(); i++) {
NamedExpression outputExpr = agg.getOutputExpressions().get(i);
if (!agg.getGroupByExpressions().contains(outputExpr)) {
remainedOutputExprList.add(outputExpr);
}
List<Expression> newGroupExpression = new ArrayList<>();
for (Expression expression : agg.getGroupByExpressions()) {
if (!removeExpression.contains(expression)) {
newGroupExpression.add(expression);
}
outputExprList.addAll(remainedOutputExprList);
return new LogicalAggregate<>(resultExprs, agg.getOutputExpressions(), agg.child());
}).toRule(RuleType.ELIMINATE_GROUP_BY_KEY);
}

/**
* find the equal expr index from expr list.
*/
public int findEqualExpr(List<SlotReference> exprList, SlotReference expr) {
for (int i = 0; i < exprList.size(); i++) {
if (exprList.get(i).equals(expr)) {
return i;
}
List<NamedExpression> newOutput = new ArrayList<>();
for (NamedExpression expression : agg.getOutputExpressions()) {
if (!removeExpression.contains(expression)) {
newOutput.add(expression);
}
}
return -1;
}

private boolean isExprsContainFdParent(List<SlotReference> candiExprs, FdItem fdItem) {
return fdItem.getParentExprs().stream().allMatch(e -> candiExprs.contains(e));
return agg.withGroupByAndOutput(newGroupExpression, newOutput);
}
}
Loading

0 comments on commit 01fa2e6

Please sign in to comment.