Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor](Nereids) refactor fold constant framework on fe #40772

Merged
merged 4 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,4 @@
*/
String name();

/**
* args type
*/
String[] argTypes();

/**
* return type
*/
String returnType();

/**
* hasVarArgsc
*/
boolean varArgs() default false;
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
package org.apache.doris.nereids.trees.expressions;

import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeAcquire;
Expand All @@ -30,18 +28,17 @@
import org.apache.doris.nereids.trees.expressions.functions.executable.StringArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.executable.TimeRoundSeries;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;

import java.lang.reflect.Array;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

Expand All @@ -52,7 +49,7 @@ public enum ExpressionEvaluator {

INSTANCE;

private ImmutableMultimap<String, FunctionInvoker> functions;
private ImmutableMultimap<String, Method> functions;

ExpressionEvaluator() {
registerFunctions();
Expand All @@ -68,23 +65,16 @@ public Expression eval(Expression expression) {
}

String fnName = null;
DataType[] args = null;
DataType ret = expression.getDataType();
if (expression instanceof BinaryArithmetic) {
BinaryArithmetic arithmetic = (BinaryArithmetic) expression;
fnName = arithmetic.getLegacyOperator().getName();
args = new DataType[]{arithmetic.left().getDataType(), arithmetic.right().getDataType()};
} else if (expression instanceof TimestampArithmetic) {
TimestampArithmetic arithmetic = (TimestampArithmetic) expression;
fnName = arithmetic.getFuncName();
args = new DataType[]{arithmetic.left().getDataType(), arithmetic.right().getDataType()};
} else if (expression instanceof BoundFunction) {
BoundFunction function = ((BoundFunction) expression);
fnName = function.getName();
args = new DataType[function.arity()];
for (int i = 0; i < function.children().size(); i++) {
args[i] = function.child(i).getDataType();
}
}

if ((Env.getCurrentEnv().isNullResultWithOneNullParamFunction(fnName))) {
Expand All @@ -95,22 +85,26 @@ public Expression eval(Expression expression) {
}
}

return invoke(expression, fnName, args);
return invoke(expression, fnName);
}

private Expression invoke(Expression expression, String fnName, DataType[] args) {
FunctionSignature signature = new FunctionSignature(fnName, args, null, false);
FunctionInvoker invoker = getFunction(signature);
if (invoker != null) {
private Expression invoke(Expression expression, String fnName) {
Method method = getFunction(fnName, expression.children());
if (method != null) {
try {
if (invoker.getSignature().hasVarArgs()) {
int fixedArgsSize = invoker.getSignature().getArgTypes().length - 1;
int totalSize = expression.children().size();
Class<?>[] parameterTypes = invoker.getMethod().getParameterTypes();
Class<?> parameterType = parameterTypes[parameterTypes.length - 1];
int varSize = method.getParameterTypes().length;
if (varSize == 0) {
return (Literal) method.invoke(null, expression.children().toArray());
}
boolean hasVarArgs = method.getParameterTypes()[varSize - 1].isArray();
if (hasVarArgs) {
int fixedArgsSize = varSize - 1;
int inputSize = expression.children().size();
Class<?>[] parameterTypes = method.getParameterTypes();
Class<?> parameterType = parameterTypes[varSize - 1];
Class<?> componentType = parameterType.getComponentType();
Object varArgs = Array.newInstance(componentType, totalSize - fixedArgsSize);
for (int i = fixedArgsSize; i < totalSize; i++) {
Object varArgs = Array.newInstance(componentType, inputSize - fixedArgsSize);
for (int i = fixedArgsSize; i < inputSize; i++) {
if (!(expression.children().get(i) instanceof NullLiteral)) {
Array.set(varArgs, i - fixedArgsSize, expression.children().get(i));
}
Expand All @@ -121,59 +115,70 @@ private Expression invoke(Expression expression, String fnName, DataType[] args)
}
objects[fixedArgsSize] = varArgs;

return invoker.invokeVars(objects);
return (Literal) method.invoke(null, varArgs);
}
return invoker.invoke(expression.children());
} catch (AnalysisException e) {
return (Literal) method.invoke(null, expression.children().toArray());
} catch (InvocationTargetException | IllegalAccessException | IllegalArgumentException e) {
return expression;
}
}
return expression;
}

private FunctionInvoker getFunction(FunctionSignature signature) {
Collection<FunctionInvoker> functionInvokers = functions.get(signature.getName());
for (FunctionInvoker candidate : functionInvokers) {
DataType[] candidateTypes = candidate.getSignature().getArgTypes();
DataType[] expectedTypes = signature.getArgTypes();
private boolean canDownCastTo(Class<?> expect, Class<?> input) {
if (DateLiteral.class.isAssignableFrom(expect)
|| DateTimeLiteral.class.isAssignableFrom(expect)) {
return expect.equals(input);
}
return expect.isAssignableFrom(input);
}

if (candidate.getSignature().hasVarArgs()) {
if (candidateTypes.length > expectedTypes.length) {
private Method getFunction(String fnName, List<Expression> inputs) {
Collection<Method> expectMethods = functions.get(fnName);
for (Method expect : expectMethods) {
boolean match = true;
int varSize = expect.getParameterTypes().length;
if (varSize == 0) {
if (inputs.size() == 0) {
return expect;
} else {
continue;
}
boolean match = true;
for (int i = 0; i < candidateTypes.length - 1; i++) {
if (!(expectedTypes[i].toCatalogDataType().matchesType(candidateTypes[i].toCatalogDataType()))) {
}
boolean hasVarArgs = expect.getParameterTypes()[varSize - 1].isArray();
if (hasVarArgs) {
int fixedArgsSize = varSize - 1;
int inputSize = inputs.size();
if (inputSize <= fixedArgsSize) {
continue;
}
Class<?>[] expectVarTypes = expect.getParameterTypes();
for (int i = 0; i < fixedArgsSize; i++) {
if (!canDownCastTo(expectVarTypes[i], inputs.get(i).getClass())) {
match = false;
break;
}
}
Type varType = candidateTypes[candidateTypes.length - 1].toCatalogDataType();
for (int i = candidateTypes.length - 1; i < expectedTypes.length; i++) {
if (!(expectedTypes[i].toCatalogDataType().matchesType(varType))) {
Class<?> varArgsType = expectVarTypes[varSize - 1];
Class<?> varArgType = varArgsType.getComponentType();
for (int i = fixedArgsSize; i < inputSize; i++) {
if (!canDownCastTo(varArgType, inputs.get(i).getClass())) {
match = false;
break;
}
}
if (match) {
return candidate;
} else {
} else {
int inputSize = inputs.size();
if (inputSize != varSize) {
continue;
}
}
if (candidateTypes.length != expectedTypes.length) {
continue;
}

boolean match = true;
for (int i = 0; i < candidateTypes.length; i++) {
if (!(expectedTypes[i].toCatalogDataType().matchesType(candidateTypes[i].toCatalogDataType()))) {
match = false;
break;
Class<?>[] expectVarTypes = expect.getParameterTypes();
for (int i = 0; i < varSize; i++) {
if (!canDownCastTo(expectVarTypes[i], inputs.get(i).getClass())) {
match = false;
}
}
}
if (match) {
return candidate;
return expect;
}
}
return null;
Expand All @@ -183,7 +188,7 @@ private void registerFunctions() {
if (functions != null) {
return;
}
ImmutableMultimap.Builder<String, FunctionInvoker> mapBuilder = new ImmutableMultimap.Builder<>();
ImmutableMultimap.Builder<String, Method> mapBuilder = new ImmutableMultimap.Builder<>();
List<Class<?>> classes = ImmutableList.of(
DateTimeAcquire.class,
DateTimeExtractAndTransform.class,
Expand All @@ -208,92 +213,10 @@ private void registerFunctions() {
this.functions = mapBuilder.build();
}

private void registerFEFunction(ImmutableMultimap.Builder<String, FunctionInvoker> mapBuilder,
private void registerFEFunction(ImmutableMultimap.Builder<String, Method> mapBuilder,
Method method, ExecFunction annotation) {
if (annotation != null) {
String name = annotation.name();
DataType returnType = DataType.convertFromString(annotation.returnType());
List<DataType> argTypes = new ArrayList<>();
for (String type : annotation.argTypes()) {
argTypes.add(TypeCoercionUtils.replaceDecimalV3WithWildcard(DataType.convertFromString(type)));
}
DataType[] array = new DataType[argTypes.size()];
for (int i = 0; i < argTypes.size(); i++) {
array[i] = argTypes.get(i);
}
FunctionSignature signature = new FunctionSignature(name, array, returnType, annotation.varArgs());
mapBuilder.put(name, new FunctionInvoker(method, signature));
}
}

/**
* function invoker.
*/
public static class FunctionInvoker {
private final Method method;
private final FunctionSignature signature;

public FunctionInvoker(Method method, FunctionSignature signature) {
this.method = method;
this.signature = signature;
}

public Method getMethod() {
return method;
}

public FunctionSignature getSignature() {
return signature;
}

public Literal invoke(List<Expression> args) throws AnalysisException {
try {
return (Literal) method.invoke(null, args.toArray());
} catch (InvocationTargetException | IllegalAccessException | IllegalArgumentException e) {
throw new AnalysisException(e.getLocalizedMessage());
}
}

public Literal invokeVars(Object[] args) throws AnalysisException {
try {
return (Literal) method.invoke(null, args);
} catch (InvocationTargetException | IllegalAccessException | IllegalArgumentException e) {
throw new AnalysisException(e.getLocalizedMessage());
}
mapBuilder.put(annotation.name(), method);
}
}

/**
* function signature.
*/
public static class FunctionSignature {
private final String name;
private final DataType[] argTypes;
private final DataType returnType;
private final boolean hasVarArgs;

public FunctionSignature(String name, DataType[] argTypes, DataType returnType, boolean hasVarArgs) {
this.name = name;
this.argTypes = argTypes;
this.returnType = returnType;
this.hasVarArgs = hasVarArgs;
}

public DataType[] getArgTypes() {
return argTypes;
}

public DataType getReturnType() {
return returnType;
}

public String getName() {
return name;
}

public boolean hasVarArgs() {
return hasVarArgs;
}
}

}
Loading
Loading