Skip to content

Commit

Permalink
[Refactor](Nereids) refactor fold constant framework on fe (#40772)
Browse files Browse the repository at this point in the history
change matching of function by notation to matching by inputs datatype
  • Loading branch information
LiBinfeng-01 authored Sep 13, 2024
1 parent 3cefe48 commit ffa9608
Show file tree
Hide file tree
Showing 9 changed files with 573 additions and 664 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,4 @@
*/
String name();

/**
* args type
*/
String[] argTypes();

/**
* return type
*/
String returnType();

/**
* hasVarArgsc
*/
boolean varArgs() default false;
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
package org.apache.doris.nereids.trees.expressions;

import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeAcquire;
Expand All @@ -30,18 +28,17 @@
import org.apache.doris.nereids.trees.expressions.functions.executable.StringArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.executable.TimeRoundSeries;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;

import java.lang.reflect.Array;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

Expand All @@ -52,7 +49,7 @@ public enum ExpressionEvaluator {

INSTANCE;

private ImmutableMultimap<String, FunctionInvoker> functions;
private ImmutableMultimap<String, Method> functions;

ExpressionEvaluator() {
registerFunctions();
Expand All @@ -68,23 +65,16 @@ public Expression eval(Expression expression) {
}

String fnName = null;
DataType[] args = null;
DataType ret = expression.getDataType();
if (expression instanceof BinaryArithmetic) {
BinaryArithmetic arithmetic = (BinaryArithmetic) expression;
fnName = arithmetic.getLegacyOperator().getName();
args = new DataType[]{arithmetic.left().getDataType(), arithmetic.right().getDataType()};
} else if (expression instanceof TimestampArithmetic) {
TimestampArithmetic arithmetic = (TimestampArithmetic) expression;
fnName = arithmetic.getFuncName();
args = new DataType[]{arithmetic.left().getDataType(), arithmetic.right().getDataType()};
} else if (expression instanceof BoundFunction) {
BoundFunction function = ((BoundFunction) expression);
fnName = function.getName();
args = new DataType[function.arity()];
for (int i = 0; i < function.children().size(); i++) {
args[i] = function.child(i).getDataType();
}
}

if ((Env.getCurrentEnv().isNullResultWithOneNullParamFunction(fnName))) {
Expand All @@ -95,22 +85,26 @@ public Expression eval(Expression expression) {
}
}

return invoke(expression, fnName, args);
return invoke(expression, fnName);
}

private Expression invoke(Expression expression, String fnName, DataType[] args) {
FunctionSignature signature = new FunctionSignature(fnName, args, null, false);
FunctionInvoker invoker = getFunction(signature);
if (invoker != null) {
private Expression invoke(Expression expression, String fnName) {
Method method = getFunction(fnName, expression.children());
if (method != null) {
try {
if (invoker.getSignature().hasVarArgs()) {
int fixedArgsSize = invoker.getSignature().getArgTypes().length - 1;
int totalSize = expression.children().size();
Class<?>[] parameterTypes = invoker.getMethod().getParameterTypes();
Class<?> parameterType = parameterTypes[parameterTypes.length - 1];
int varSize = method.getParameterTypes().length;
if (varSize == 0) {
return (Literal) method.invoke(null, expression.children().toArray());
}
boolean hasVarArgs = method.getParameterTypes()[varSize - 1].isArray();
if (hasVarArgs) {
int fixedArgsSize = varSize - 1;
int inputSize = expression.children().size();
Class<?>[] parameterTypes = method.getParameterTypes();
Class<?> parameterType = parameterTypes[varSize - 1];
Class<?> componentType = parameterType.getComponentType();
Object varArgs = Array.newInstance(componentType, totalSize - fixedArgsSize);
for (int i = fixedArgsSize; i < totalSize; i++) {
Object varArgs = Array.newInstance(componentType, inputSize - fixedArgsSize);
for (int i = fixedArgsSize; i < inputSize; i++) {
if (!(expression.children().get(i) instanceof NullLiteral)) {
Array.set(varArgs, i - fixedArgsSize, expression.children().get(i));
}
Expand All @@ -121,59 +115,70 @@ private Expression invoke(Expression expression, String fnName, DataType[] args)
}
objects[fixedArgsSize] = varArgs;

return invoker.invokeVars(objects);
return (Literal) method.invoke(null, varArgs);
}
return invoker.invoke(expression.children());
} catch (AnalysisException e) {
return (Literal) method.invoke(null, expression.children().toArray());
} catch (InvocationTargetException | IllegalAccessException | IllegalArgumentException e) {
return expression;
}
}
return expression;
}

private FunctionInvoker getFunction(FunctionSignature signature) {
Collection<FunctionInvoker> functionInvokers = functions.get(signature.getName());
for (FunctionInvoker candidate : functionInvokers) {
DataType[] candidateTypes = candidate.getSignature().getArgTypes();
DataType[] expectedTypes = signature.getArgTypes();
private boolean canDownCastTo(Class<?> expect, Class<?> input) {
if (DateLiteral.class.isAssignableFrom(expect)
|| DateTimeLiteral.class.isAssignableFrom(expect)) {
return expect.equals(input);
}
return expect.isAssignableFrom(input);
}

if (candidate.getSignature().hasVarArgs()) {
if (candidateTypes.length > expectedTypes.length) {
private Method getFunction(String fnName, List<Expression> inputs) {
Collection<Method> expectMethods = functions.get(fnName);
for (Method expect : expectMethods) {
boolean match = true;
int varSize = expect.getParameterTypes().length;
if (varSize == 0) {
if (inputs.size() == 0) {
return expect;
} else {
continue;
}
boolean match = true;
for (int i = 0; i < candidateTypes.length - 1; i++) {
if (!(expectedTypes[i].toCatalogDataType().matchesType(candidateTypes[i].toCatalogDataType()))) {
}
boolean hasVarArgs = expect.getParameterTypes()[varSize - 1].isArray();
if (hasVarArgs) {
int fixedArgsSize = varSize - 1;
int inputSize = inputs.size();
if (inputSize <= fixedArgsSize) {
continue;
}
Class<?>[] expectVarTypes = expect.getParameterTypes();
for (int i = 0; i < fixedArgsSize; i++) {
if (!canDownCastTo(expectVarTypes[i], inputs.get(i).getClass())) {
match = false;
break;
}
}
Type varType = candidateTypes[candidateTypes.length - 1].toCatalogDataType();
for (int i = candidateTypes.length - 1; i < expectedTypes.length; i++) {
if (!(expectedTypes[i].toCatalogDataType().matchesType(varType))) {
Class<?> varArgsType = expectVarTypes[varSize - 1];
Class<?> varArgType = varArgsType.getComponentType();
for (int i = fixedArgsSize; i < inputSize; i++) {
if (!canDownCastTo(varArgType, inputs.get(i).getClass())) {
match = false;
break;
}
}
if (match) {
return candidate;
} else {
} else {
int inputSize = inputs.size();
if (inputSize != varSize) {
continue;
}
}
if (candidateTypes.length != expectedTypes.length) {
continue;
}

boolean match = true;
for (int i = 0; i < candidateTypes.length; i++) {
if (!(expectedTypes[i].toCatalogDataType().matchesType(candidateTypes[i].toCatalogDataType()))) {
match = false;
break;
Class<?>[] expectVarTypes = expect.getParameterTypes();
for (int i = 0; i < varSize; i++) {
if (!canDownCastTo(expectVarTypes[i], inputs.get(i).getClass())) {
match = false;
}
}
}
if (match) {
return candidate;
return expect;
}
}
return null;
Expand All @@ -183,7 +188,7 @@ private void registerFunctions() {
if (functions != null) {
return;
}
ImmutableMultimap.Builder<String, FunctionInvoker> mapBuilder = new ImmutableMultimap.Builder<>();
ImmutableMultimap.Builder<String, Method> mapBuilder = new ImmutableMultimap.Builder<>();
List<Class<?>> classes = ImmutableList.of(
DateTimeAcquire.class,
DateTimeExtractAndTransform.class,
Expand All @@ -208,92 +213,10 @@ private void registerFunctions() {
this.functions = mapBuilder.build();
}

private void registerFEFunction(ImmutableMultimap.Builder<String, FunctionInvoker> mapBuilder,
private void registerFEFunction(ImmutableMultimap.Builder<String, Method> mapBuilder,
Method method, ExecFunction annotation) {
if (annotation != null) {
String name = annotation.name();
DataType returnType = DataType.convertFromString(annotation.returnType());
List<DataType> argTypes = new ArrayList<>();
for (String type : annotation.argTypes()) {
argTypes.add(TypeCoercionUtils.replaceDecimalV3WithWildcard(DataType.convertFromString(type)));
}
DataType[] array = new DataType[argTypes.size()];
for (int i = 0; i < argTypes.size(); i++) {
array[i] = argTypes.get(i);
}
FunctionSignature signature = new FunctionSignature(name, array, returnType, annotation.varArgs());
mapBuilder.put(name, new FunctionInvoker(method, signature));
}
}

/**
* function invoker.
*/
public static class FunctionInvoker {
private final Method method;
private final FunctionSignature signature;

public FunctionInvoker(Method method, FunctionSignature signature) {
this.method = method;
this.signature = signature;
}

public Method getMethod() {
return method;
}

public FunctionSignature getSignature() {
return signature;
}

public Literal invoke(List<Expression> args) throws AnalysisException {
try {
return (Literal) method.invoke(null, args.toArray());
} catch (InvocationTargetException | IllegalAccessException | IllegalArgumentException e) {
throw new AnalysisException(e.getLocalizedMessage());
}
}

public Literal invokeVars(Object[] args) throws AnalysisException {
try {
return (Literal) method.invoke(null, args);
} catch (InvocationTargetException | IllegalAccessException | IllegalArgumentException e) {
throw new AnalysisException(e.getLocalizedMessage());
}
mapBuilder.put(annotation.name(), method);
}
}

/**
* function signature.
*/
public static class FunctionSignature {
private final String name;
private final DataType[] argTypes;
private final DataType returnType;
private final boolean hasVarArgs;

public FunctionSignature(String name, DataType[] argTypes, DataType returnType, boolean hasVarArgs) {
this.name = name;
this.argTypes = argTypes;
this.returnType = returnType;
this.hasVarArgs = hasVarArgs;
}

public DataType[] getArgTypes() {
return argTypes;
}

public DataType getReturnType() {
return returnType;
}

public String getName() {
return name;
}

public boolean hasVarArgs() {
return hasVarArgs;
}
}

}
Loading

0 comments on commit ffa9608

Please sign in to comment.