Skip to content

Commit

Permalink
[opt](Nereids) support search from override udfs with same arity (#40432
Browse files Browse the repository at this point in the history
)

create alias function f1(int) with parameter(id) as abs(id);
create alias function f1(string) with parameter(id) as substr(id, 2);
select f1('1'); -- bind on f1(string)
select f1(1);   -- bind on f1(int)

test case already existed in P0
  • Loading branch information
morrySnow authored Sep 5, 2024
1 parent 1666c83 commit 82c2650
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.udf.UdfBuilder;
import org.apache.doris.nereids.types.DataType;
Expand Down Expand Up @@ -156,18 +158,33 @@ public FunctionBuilder findFunctionBuilder(String dbName, String name, List<?> a
+ "' which has " + arity + " arity. Candidate functions are: " + candidateHints);
}
if (candidateBuilders.size() > 1) {
String candidateHints = getCandidateHint(name, candidateBuilders);
// TODO: NereidsPlanner not supported override function by the same arity, we will support it later
if (ConnectContext.get() != null) {
try {
ConnectContext.get().getSessionVariable().enableFallbackToOriginalPlannerOnce();
} catch (Throwable t) {
// ignore error
boolean needChooseOne = true;
List<FunctionSignature> signatures = Lists.newArrayListWithCapacity(candidateBuilders.size());
for (FunctionBuilder functionBuilder : candidateBuilders) {
if (functionBuilder instanceof UdfBuilder) {
signatures.addAll(((UdfBuilder) functionBuilder).getSignatures());
} else {
needChooseOne = false;
break;
}
}
for (Object argument : arguments) {
if (!(argument instanceof Expression)) {
needChooseOne = false;
break;
}
}
if (needChooseOne) {
FunctionSignature signature = new UdfSignatureSearcher(signatures, (List) arguments).getSignature();
for (int i = 0; i < signatures.size(); i++) {
if (signatures.get(i).equals(signature)) {
return candidateBuilders.get(i);
}
}
}
String candidateHints = getCandidateHint(name, candidateBuilders);
throw new AnalysisException("Function '" + qualifiedName + "' is ambiguous: " + candidateHints);
}

return candidateBuilders.get(0);
}

Expand Down Expand Up @@ -235,4 +252,63 @@ public void dropUdf(String dbName, String name, List<DataType> argTypes) {
.removeIf(builder -> ((UdfBuilder) builder).getArgTypes().equals(argTypes));
}
}

/**
* use for search appropriate signature for UDFs if candidate more than one.
*/
static class UdfSignatureSearcher implements ExplicitlyCastableSignature {

private final List<FunctionSignature> signatures;
private final List<Expression> arguments;

public UdfSignatureSearcher(List<FunctionSignature> signatures, List<Expression> arguments) {
this.signatures = signatures;
this.arguments = arguments;
}

@Override
public List<FunctionSignature> getSignatures() {
return signatures;
}

@Override
public FunctionSignature getSignature() {
return searchSignature(signatures);
}

@Override
public boolean nullable() {
throw new AnalysisException("could not call nullable on UdfSignatureSearcher");
}

@Override
public List<Expression> children() {
return arguments;
}

@Override
public Expression child(int index) {
return arguments.get(index);
}

@Override
public int arity() {
return arguments.size();
}

@Override
public <T> Optional<T> getMutableState(String key) {
return Optional.empty();
}

@Override
public void setMutableState(String key, Object value) {
}

@Override
public Expression withChildren(List<Expression> children) {
throw new AnalysisException("could not call withChildren on UdfSignatureSearcher");

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiFunction;

public class FunctionSignature {
public final DataType returnType;
Expand Down Expand Up @@ -78,21 +77,6 @@ public FunctionSignature withArgumentTypes(boolean hasVarArgs, List<? extends Da
return new FunctionSignature(returnType, hasVarArgs, argumentsTypes);
}

/**
* change argument type by the signature's type and the corresponding argument's type
* @param arguments arguments
* @param transform param1: signature's type, param2: argument's type, return new type you want to change
* @return
*/
public FunctionSignature withArgumentTypes(List<Expression> arguments,
BiFunction<DataType, Expression, DataType> transform) {
List<DataType> newTypes = Lists.newArrayList();
for (int i = 0; i < arguments.size(); i++) {
newTypes.add(transform.apply(getArgType(i), arguments.get(i)));
}
return withArgumentTypes(hasVarArgs, newTypes);
}

/**
* change argument type by the signature's type and the corresponding argument's type
* @param arguments arguments
Expand Down Expand Up @@ -145,6 +129,24 @@ public String toString() {
.toString();
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
FunctionSignature signature = (FunctionSignature) o;
return hasVarArgs == signature.hasVarArgs && arity == signature.arity && Objects.equals(returnType,
signature.returnType) && Objects.equals(argumentsTypes, signature.argumentsTypes);
}

@Override
public int hashCode() {
return Objects.hash(returnType, hasVarArgs, argumentsTypes, arity);
}

public static class FuncSigBuilder {
public final DataType returnType;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.NullType;

import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;

Expand Down Expand Up @@ -62,8 +61,7 @@ public AliasUdf(String name, List<DataType> argTypes, UnboundFunction unboundFun

@Override
public List<FunctionSignature> getSignatures() {
return ImmutableList.of(Suppliers.memoize(() -> FunctionSignature
.of(NullType.INSTANCE, argTypes)).get());
return ImmutableList.of(FunctionSignature.of(NullType.INSTANCE, argTypes));
}

public List<String> getParameters() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.doris.nereids.trees.expressions.functions.udf;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.common.Pair;
import org.apache.doris.common.util.ReflectionUtils;
import org.apache.doris.nereids.analyzer.Scope;
Expand All @@ -25,7 +26,6 @@
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;

Expand Down Expand Up @@ -53,6 +53,11 @@ public List<DataType> getArgTypes() {
return aliasUdf.getArgTypes();
}

@Override
public List<FunctionSignature> getSignatures() {
return aliasUdf.getSignatures();
}

@Override
public Class<? extends BoundFunction> functionClass() {
return AliasUdf.class;
Expand Down Expand Up @@ -109,17 +114,4 @@ public Expression visitSlotReference(SlotReference slotReference, ExpressionRewr

return Pair.of(udfAnalyzer.analyze(aliasUdf.getUnboundFunction()), boundAliasFunction);
}

private static class SlotReplacer extends DefaultExpressionRewriter<Map<SlotReference, Expression>> {
public static final SlotReplacer INSTANCE = new SlotReplacer();

public Expression replace(Expression expression, Map<SlotReference, Expression> context) {
return expression.accept(this, context);
}

@Override
public Expression visitSlotReference(SlotReference slot, Map<SlotReference, Expression> context) {
return context.get(slot);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.doris.nereids.trees.expressions.functions.udf;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.common.Pair;
import org.apache.doris.common.util.ReflectionUtils;
import org.apache.doris.nereids.trees.expressions.Expression;
Expand Down Expand Up @@ -50,6 +51,11 @@ public List<DataType> getArgTypes() {
.collect(Collectors.toList())).get();
}

@Override
public List<FunctionSignature> getSignatures() {
return udaf.getSignatures();
}

@Override
public Class<? extends BoundFunction> functionClass() {
return JavaUdaf.class;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.doris.nereids.trees.expressions.functions.udf;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.common.Pair;
import org.apache.doris.common.util.ReflectionUtils;
import org.apache.doris.nereids.trees.expressions.Expression;
Expand Down Expand Up @@ -52,6 +53,11 @@ public List<DataType> getArgTypes() {
.collect(Collectors.toList())).get();
}

@Override
public List<FunctionSignature> getSignatures() {
return udf.getSignatures();
}

@Override
public Class<? extends BoundFunction> functionClass() {
return JavaUdf.class;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.doris.nereids.trees.expressions.functions.udf;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.common.Pair;
import org.apache.doris.common.util.ReflectionUtils;
import org.apache.doris.nereids.trees.expressions.Expression;
Expand Down Expand Up @@ -52,6 +53,11 @@ public List<DataType> getArgTypes() {
.collect(Collectors.toList())).get();
}

@Override
public List<FunctionSignature> getSignatures() {
return udf.getSignatures();
}

@Override
public Class<? extends BoundFunction> functionClass() {
return JavaUdtf.class;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.doris.nereids.trees.expressions.functions.udf;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.types.DataType;

Expand All @@ -27,4 +28,6 @@
*/
public abstract class UdfBuilder extends FunctionBuilder {
public abstract List<DataType> getArgTypes();

public abstract List<FunctionSignature> getSignatures();
}

0 comments on commit 82c2650

Please sign in to comment.