Skip to content

Commit

Permalink
Check arg type against Py UDF signature at query compile time (#5254)
Browse files Browse the repository at this point in the history
* WIP check arg type against defs at parsing time

* Fix spotless check errors

* Fix bugs and old test cases

* Refactor parsing code of UDF signatures

* Refactoring interface

* Code clean up/add more comments/test cases

* Fix a test failure

* Skip vermin check for explictly ver checked code

* More novermin check

* Make comment more clear and new test case
  • Loading branch information
jmao-denver committed Mar 25, 2024
1 parent 4089d2b commit bcbeff7
Show file tree
Hide file tree
Showing 12 changed files with 434 additions and 163 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ public final class QueryLanguageParser extends GenericVisitorAdapter<Class<?>, Q
private final Map<String, Class<?>> staticImportLookupCache = new HashMap<>();

// We need some class to represent null. We know for certain that this one won't be used...
private static final Class<?> NULL_CLASS = QueryLanguageParser.class;
public static final Class<?> NULL_CLASS = QueryLanguageParser.class;

/**
* The result of the QueryLanguageParser for the expression passed given to the constructor.
Expand Down Expand Up @@ -1939,7 +1939,7 @@ private static boolean isAssociativitySafeExpression(Expression expr) {
* @return {@code true} if a conversion from {@code original} to {@code target} is a widening conversion; otherwise,
* {@code false}.
*/
static boolean isWideningPrimitiveConversion(Class<?> original, Class<?> target) {
public static boolean isWideningPrimitiveConversion(Class<?> original, Class<?> target) {
if (original == null || !original.isPrimitive() || target == null || !target.isPrimitive()
|| original.equals(void.class) || target.equals(void.class)) {
throw new IllegalArgumentException("Arguments must be a primitive type (excluding void)!");
Expand Down Expand Up @@ -1968,6 +1968,7 @@ static boolean isWideningPrimitiveConversion(Class<?> original, Class<?> target)
return false;
}


private enum LanguageParserPrimitiveType {
// Including "Enum" (or really, any differentiating string) in these names is important. They're used
// in a switch() statement, which apparently does not support qualified names. And we can't use
Expand Down Expand Up @@ -2498,13 +2499,37 @@ public Class<?> visit(MethodCallExpr n, VisitArgs printer) {

// Attempt python function call vectorization.
if (scopeType != null && PyCallableWrapper.class.isAssignableFrom(scopeType)) {
verifyPyCallableArguments(n, argTypes);
tryVectorizePythonCallable(n, scopeType, convertedArgExpressions, argTypes);
}

return calculateMethodReturnTypeUsingGenerics(scopeType, n.getScope().orElse(null), method, expressionTypes,
typeArguments);
}

private void verifyPyCallableArguments(@NotNull MethodCallExpr n, @NotNull Class<?>[] argTypes) {
final String invokedMethodName = n.getNameAsString();

if (GET_ATTRIBUTE_METHOD_NAME.equals(invokedMethodName)) {
// Currently Python UDF handling is only supported for top module level function(callable) calls.
// The getAttribute() calls which is needed to support Python method calls, which is beyond the scope of
// current implementation. So we are skipping the argument verification for getAttribute() calls.
return;
}
if (!n.containsData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS)) {
return;
}
final PyCallableDetails pyCallableDetails = n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS);
final String pyMethodName = pyCallableDetails.pythonMethodName;
final Object methodVar = queryScopeVariables.get(pyMethodName);
if (!(methodVar instanceof PyCallableWrapper)) {
return;
}
final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) methodVar;
pyCallableWrapper.parseSignature();
pyCallableWrapper.verifyArguments(argTypes);
}

private Optional<CastExpr> makeCastExpressionForPyCallable(Class<?> retType, MethodCallExpr callMethodCall) {
if (retType.isPrimitive()) {
return Optional.of(new CastExpr(
Expand Down Expand Up @@ -2552,7 +2577,7 @@ private Optional<Class<?>> pyCallableReturnType(@NotNull MethodCallExpr n) {
}
final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) paramValueRaw;
pyCallableWrapper.parseSignature();
return Optional.ofNullable(pyCallableWrapper.getReturnType());
return Optional.ofNullable(pyCallableWrapper.getSignature().getReturnType());
}

@NotNull
Expand Down Expand Up @@ -2683,7 +2708,8 @@ private void checkVectorizability(@NotNull final MethodCallExpr n,
pyCallableWrapper.parseSignature();
if (!pyCallableWrapper.isVectorizableReturnType()) {
throw new PythonCallVectorizationFailure(
"Python function return type is not supported: " + pyCallableWrapper.getReturnType());
"Python function return type is not supported: "
+ pyCallableWrapper.getSignature().getReturnType());
}

// Python vectorized functions(numba, DH) return arrays of primitive/Object types. This will break the generated
Expand Down Expand Up @@ -2726,11 +2752,10 @@ private void checkVectorizability(@NotNull final MethodCallExpr n,
}
}

List<Class<?>> paramTypes = pyCallableWrapper.getParamTypes();
if (paramTypes.size() != expressions.length) {
if (pyCallableWrapper.getSignature().getParameters().size() != expressions.length) {
// note vectorization doesn't handle Python variadic arguments
throw new PythonCallVectorizationFailure("Python function argument count mismatch: " + n + " "
+ paramTypes.size() + " vs. " + expressions.length);
+ pyCallableWrapper.getSignature().getParameters().size() + " vs. " + expressions.length);
}
}

Expand All @@ -2739,10 +2764,9 @@ private void prepareVectorizationArgs(
Expression[] expressions,
Class<?>[] argTypes,
PyCallableWrapper pyCallableWrapper) {
List<Class<?>> paramTypes = pyCallableWrapper.getParamTypes();
if (paramTypes.size() != expressions.length) {
if (pyCallableWrapper.getSignature().getParameters().size() != expressions.length) {
throw new PythonCallVectorizationFailure("Python function argument count mismatch: " + n + " "
+ paramTypes.size() + " vs. " + expressions.length);
+ pyCallableWrapper.getSignature().getParameters().size() + " vs. " + expressions.length);
}

pyCallableWrapper.initializeChunkArguments();
Expand All @@ -2763,11 +2787,6 @@ private void prepareVectorizationArgs(
} else {
throw new IllegalStateException("Vectorizability check failed: " + n);
}

if (!isSafelyCoerceable(argTypes[i], paramTypes.get(i))) {
throw new PythonCallVectorizationFailure("Python vectorized function argument type mismatch: " + n + " "
+ argTypes[i].getSimpleName() + " -> " + paramTypes.get(i).getSimpleName());
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ private void checkAndInitializeVectorization(QueryLanguageParser.Result result,
final PyCallableWrapperJpyImpl pyCallableWrapper = cws[0];

if (pyCallableWrapper.isVectorizable()) {
checkReturnType(result, pyCallableWrapper.getReturnType());
checkReturnType(result, pyCallableWrapper.getSignature().getReturnType());

for (String variable : result.getVariablesUsed()) {
if (variable.equals("i")) {
Expand All @@ -284,7 +284,8 @@ private void checkAndInitializeVectorization(QueryLanguageParser.Result result,
ArgumentsChunked argumentsChunked = pyCallableWrapper.buildArgumentsChunked(usedColumns);
PyObject vectorized = pyCallableWrapper.vectorizedCallable();
DeephavenCompatibleFunction dcf = DeephavenCompatibleFunction.create(vectorized,
pyCallableWrapper.getReturnType(), usedColumns.toArray(new String[0]), argumentsChunked, true);
pyCallableWrapper.getSignature().getReturnType(), usedColumns.toArray(new String[0]),
argumentsChunked, true);
setFilter(new ConditionFilter.ChunkFilter(
dcf.toFilterKernel(),
dcf.getColumnNames().toArray(new String[0]),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ private void checkAndInitializeVectorization(Map<String, ColumnDefinition<?>> co
PyObject vectorized = pyCallableWrapper.vectorizedCallable();
formulaColumnPython = FormulaColumnPython.create(this.columnName,
DeephavenCompatibleFunction.create(vectorized,
pyCallableWrapper.getReturnType(), this.analyzedFormula.sourceDescriptor.sources,
pyCallableWrapper.getSignature().getReturnType(),
this.analyzedFormula.sourceDescriptor.sources,
argumentsChunked,
true));
formulaColumnPython.initDef(columnDefinitionMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.jpy.PyObject;

import java.util.List;
import java.util.Set;

/**
* Created by rbasralian on 8/12/23
Expand All @@ -19,8 +20,6 @@ public interface PyCallableWrapper {

Object call(Object... args);

List<Class<?>> getParamTypes();

boolean isVectorized();

boolean isVectorizable();
Expand All @@ -31,7 +30,46 @@ public interface PyCallableWrapper {

void addChunkArgument(ChunkArgument chunkArgument);

Class<?> getReturnType();
Signature getSignature();

void verifyArguments(Class<?>[] argTypes);

class Parameter {
private final String name;
private final Set<Class<?>> possibleTypes;


public Parameter(String name, Set<Class<?>> possibleTypes) {
this.name = name;
this.possibleTypes = possibleTypes;
}

public Set<Class<?>> getPossibleTypes() {
return possibleTypes;
}

public String getName() {
return name;
}
}

class Signature {
private final List<Parameter> parameters;
private final Class<?> returnType;

public Signature(List<Parameter> parameters, Class<?> returnType) {
this.parameters = parameters;
this.returnType = returnType;
}

public List<Parameter> getParameters() {
return parameters;
}

public Class<?> getReturnType() {
return returnType;
}
}

abstract class ChunkArgument {
private final Class<?> type;
Expand Down Expand Up @@ -88,4 +126,5 @@ public Object getValue() {
}

boolean isVectorizableReturnType();

}
Loading

0 comments on commit bcbeff7

Please sign in to comment.