Skip to content

Commit

Permalink
[Refactor](Nereids) refactor fold constant framework on fe
Browse files Browse the repository at this point in the history
  • Loading branch information
LiBinfeng-01 committed Sep 13, 2024
1 parent 4466541 commit 1cb1888
Show file tree
Hide file tree
Showing 9 changed files with 541 additions and 577 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,6 @@
*/
String name();

/**
* args type
*/
String[] argTypes();

/**
* return type
*/
String returnType();

/**
* hasVarArgsc
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.expressions;

import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
Expand All @@ -33,15 +32,13 @@
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;

import java.lang.reflect.Array;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

Expand Down Expand Up @@ -99,12 +96,11 @@ public Expression eval(Expression expression) {
}

private Expression invoke(Expression expression, String fnName, DataType[] args) {
FunctionSignature signature = new FunctionSignature(fnName, args, null, false);
FunctionInvoker invoker = getFunction(signature);
FunctionInvoker invoker = getFunction(fnName, expression.children());
if (invoker != null) {
try {
if (invoker.getSignature().hasVarArgs()) {
int fixedArgsSize = invoker.getSignature().getArgTypes().length - 1;
int fixedArgsSize = invoker.getMethod().getParameterTypes().length - 1;
int totalSize = expression.children().size();
Class<?>[] parameterTypes = invoker.getMethod().getParameterTypes();
Class<?> parameterType = parameterTypes[parameterTypes.length - 1];
Expand All @@ -131,49 +127,48 @@ private Expression invoke(Expression expression, String fnName, DataType[] args)
return expression;
}

private FunctionInvoker getFunction(FunctionSignature signature) {
Collection<FunctionInvoker> functionInvokers = functions.get(signature.getName());
for (FunctionInvoker candidate : functionInvokers) {
DataType[] candidateTypes = candidate.getSignature().getArgTypes();
DataType[] expectedTypes = signature.getArgTypes();
private boolean canDownCastTo(Class<?> expect, Class<?> input) {
return expect.isAssignableFrom(input);
}

if (candidate.getSignature().hasVarArgs()) {
if (candidateTypes.length > expectedTypes.length) {
private FunctionInvoker getFunction(String fnName, List<Expression> inputs) {
Collection<FunctionInvoker> functionInvokers = functions.get(fnName);
for (FunctionInvoker expect : functionInvokers) {
boolean match = true;
if (expect.getSignature().hasVarArgs()) {
int fixedArgsSize = expect.getMethod().getParameterTypes().length - 1;
int inputSize = inputs.size();
if (inputSize <= fixedArgsSize) {
continue;
}
boolean match = true;
for (int i = 0; i < candidateTypes.length - 1; i++) {
if (!(expectedTypes[i].toCatalogDataType().matchesType(candidateTypes[i].toCatalogDataType()))) {
Class<?>[] expectVarTypes = expect.getMethod().getParameterTypes();
for (int i = 0; i < fixedArgsSize; i++) {
if (!canDownCastTo(expectVarTypes[i], inputs.get(i).getClass())) {
match = false;
break;
}
}
Type varType = candidateTypes[candidateTypes.length - 1].toCatalogDataType();
for (int i = candidateTypes.length - 1; i < expectedTypes.length; i++) {
if (!(expectedTypes[i].toCatalogDataType().matchesType(varType))) {
Class<?> varArgsType = expectVarTypes[expectVarTypes.length - 1];
Class<?> varArgType = varArgsType.getComponentType();
for (int i = fixedArgsSize; i < inputSize; i++) {
if (!canDownCastTo(varArgType, inputs.get(i).getClass())) {
match = false;
break;
}
}
if (match) {
return candidate;
} else {
} else {
int fixedArgsSize = expect.getMethod().getParameterTypes().length;
int inputSize = inputs.size();
if (inputSize != fixedArgsSize) {
continue;
}
}
if (candidateTypes.length != expectedTypes.length) {
continue;
}

boolean match = true;
for (int i = 0; i < candidateTypes.length; i++) {
if (!(expectedTypes[i].toCatalogDataType().matchesType(candidateTypes[i].toCatalogDataType()))) {
match = false;
break;
Class<?>[] expectVarTypes = expect.getMethod().getParameterTypes();
for (int i = 0; i < fixedArgsSize; i++) {
if (!canDownCastTo(expectVarTypes[i], inputs.get(i).getClass())) {
match = false;
}
}
}
if (match) {
return candidate;
return expect;
}
}
return null;
Expand Down Expand Up @@ -212,16 +207,7 @@ private void registerFEFunction(ImmutableMultimap.Builder<String, FunctionInvoke
Method method, ExecFunction annotation) {
if (annotation != null) {
String name = annotation.name();
DataType returnType = DataType.convertFromString(annotation.returnType());
List<DataType> argTypes = new ArrayList<>();
for (String type : annotation.argTypes()) {
argTypes.add(TypeCoercionUtils.replaceDecimalV3WithWildcard(DataType.convertFromString(type)));
}
DataType[] array = new DataType[argTypes.size()];
for (int i = 0; i < argTypes.size(); i++) {
array[i] = argTypes.get(i);
}
FunctionSignature signature = new FunctionSignature(name, array, returnType, annotation.varArgs());
FunctionSignature signature = new FunctionSignature(name, annotation.varArgs());
mapBuilder.put(name, new FunctionInvoker(method, signature));
}
}
Expand Down Expand Up @@ -268,25 +254,13 @@ public Literal invokeVars(Object[] args) throws AnalysisException {
*/
public static class FunctionSignature {
private final String name;
private final DataType[] argTypes;
private final DataType returnType;
private final boolean hasVarArgs;

public FunctionSignature(String name, DataType[] argTypes, DataType returnType, boolean hasVarArgs) {
public FunctionSignature(String name, boolean hasVarArgs) {
this.name = name;
this.argTypes = argTypes;
this.returnType = returnType;
this.hasVarArgs = hasVarArgs;
}

public DataType[] getArgTypes() {
return argTypes;
}

public DataType getReturnType() {
return returnType;
}

public String getName() {
return name;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ public class DateTimeAcquire {
/**
* date acquire function: now
*/
@ExecFunction(name = "now", argTypes = {}, returnType = "DATETIME")
@ExecFunction(name = "now")
public static Expression now() {
return DateTimeLiteral.fromJavaDateType(LocalDateTime.now(DateUtils.getTimeZone()));
}

@ExecFunction(name = "now", argTypes = {"INT"}, returnType = "DATETIMEV2")
@ExecFunction(name = "now")
public static Expression now(IntegerLiteral precision) {
return DateTimeV2Literal.fromJavaDateType(LocalDateTime.now(DateUtils.getTimeZone()),
precision.getValue());
Expand All @@ -50,38 +50,38 @@ public static Expression now(IntegerLiteral precision) {
/**
* date acquire function: current_timestamp
*/
@ExecFunction(name = "current_timestamp", argTypes = {}, returnType = "DATETIME")
@ExecFunction(name = "current_timestamp")
public static Expression currentTimestamp() {
return DateTimeLiteral.fromJavaDateType(LocalDateTime.now(DateUtils.getTimeZone()));
}

@ExecFunction(name = "current_timestamp", argTypes = {"INT"}, returnType = "DATETIMEV2")
@ExecFunction(name = "current_timestamp")
public static Expression currentTimestamp(IntegerLiteral precision) {
return DateTimeV2Literal.fromJavaDateType(LocalDateTime.now(DateUtils.getTimeZone()), precision.getValue());
}

/**
* date acquire function: localtime/localtimestamp
*/
@ExecFunction(name = "localtime", argTypes = {}, returnType = "DATETIME")
@ExecFunction(name = "localtime")
public static Expression localTime() {
return DateTimeLiteral.fromJavaDateType(LocalDateTime.now(DateUtils.getTimeZone()));
}

@ExecFunction(name = "localtimestamp", argTypes = {}, returnType = "DATETIME")
@ExecFunction(name = "localtimestamp")
public static Expression localTimestamp() {
return DateTimeV2Literal.fromJavaDateType(LocalDateTime.now(DateUtils.getTimeZone()));
}

/**
* date acquire function: current_date
*/
@ExecFunction(name = "curdate", argTypes = {}, returnType = "DATE")
@ExecFunction(name = "curdate")
public static Expression curDate() {
return DateLiteral.fromJavaDateType(LocalDateTime.now(DateUtils.getTimeZone()));
}

@ExecFunction(name = "current_date", argTypes = {}, returnType = "DATE")
@ExecFunction(name = "current_date")
public static Expression currentDate() {
return DateLiteral.fromJavaDateType(LocalDateTime.now(DateUtils.getTimeZone()));
}
Expand All @@ -90,28 +90,28 @@ public static Expression currentDate() {
// /**
// * date acquire function: current_time
// */
// @ExecFunction(name = "curtime", argTypes = {}, returnType = "TIME")
// @ExecFunction(name = "curtime")
// public static Expression curTime() {
// return DateTimeLiteral.fromJavaDateType(LocalDateTime.now(DateUtils.getTimeZone()));
// }

// @ExecFunction(name = "current_time", argTypes = {}, returnType = "TIME")
// @ExecFunction(name = "current_time")
// public static Expression currentTime() {
// return DateTimeLiteral.fromJavaDateType(LocalDateTime.now(DateUtils.getTimeZone()));
// }

/**
* date transformation function: unix_timestamp
*/
@ExecFunction(name = "unix_timestamp", argTypes = {}, returnType = "INT")
@ExecFunction(name = "unix_timestamp")
public static Expression unixTimestamp() {
return new IntegerLiteral((int) (System.currentTimeMillis() / 1000L));
}

/**
* date transformation function: utc_timestamp
*/
@ExecFunction(name = "utc_timestamp", argTypes = {}, returnType = "INT")
@ExecFunction(name = "utc_timestamp")
public static Expression utcTimestamp() {
return DateTimeLiteral.fromJavaDateType(LocalDateTime.now(ZoneId.of("UTC+0")));
}
Expand Down
Loading

0 comments on commit 1cb1888

Please sign in to comment.