From bcbeff73fb1a0810b060a814781d2ef0cce0d1d2 Mon Sep 17 00:00:00 2001 From: Jianfeng Mao <4297243+jmao-denver@users.noreply.github.com> Date: Mon, 25 Mar 2024 08:34:16 -0600 Subject: [PATCH] Check arg type against Py UDF signature at query compile time (#5254) * WIP check arg type against defs at parsing time * Fix spotless check errors * Fix bugs and old test cases * Refactor parsing code of UDF signatures * Refactoring interface * Code clean up/add more comments/test cases * Fix a test failure * Skip vermin check for explictly ver checked code * More novermin check * Make comment more clear and new test case --- .../table/impl/lang/QueryLanguageParser.java | 49 ++++-- .../impl/select/AbstractConditionFilter.java | 5 +- .../table/impl/select/DhFormulaColumn.java | 3 +- .../engine/util/PyCallableWrapper.java | 45 ++++- .../engine/util/PyCallableWrapperJpyImpl.java | 154 +++++++++++++----- .../impl/lang/PyCallableWrapperDummyImpl.java | 11 +- py/server/deephaven/_udf.py | 115 +++++++++++-- py/server/deephaven/dtypes.py | 69 +------- py/server/tests/test_numba_guvectorize.py | 13 +- ...est_udf_numpy_args.py => test_udf_args.py} | 117 +++++++++++-- .../tests/test_udf_return_java_values.py | 11 +- py/server/tests/test_vectorization.py | 5 +- 12 files changed, 434 insertions(+), 163 deletions(-) rename py/server/tests/{test_udf_numpy_args.py => test_udf_args.py} (79%) diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java index 0f636d9a5b5..ae037b986d6 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java @@ -153,7 +153,7 @@ public final class QueryLanguageParser extends GenericVisitorAdapter, Q private final Map> staticImportLookupCache = new HashMap<>(); // We need some class to represent null. We know for certain that this one won't be used... - private static final Class NULL_CLASS = QueryLanguageParser.class; + public static final Class NULL_CLASS = QueryLanguageParser.class; /** * The result of the QueryLanguageParser for the expression passed given to the constructor. @@ -1939,7 +1939,7 @@ private static boolean isAssociativitySafeExpression(Expression expr) { * @return {@code true} if a conversion from {@code original} to {@code target} is a widening conversion; otherwise, * {@code false}. */ - static boolean isWideningPrimitiveConversion(Class original, Class target) { + public static boolean isWideningPrimitiveConversion(Class original, Class target) { if (original == null || !original.isPrimitive() || target == null || !target.isPrimitive() || original.equals(void.class) || target.equals(void.class)) { throw new IllegalArgumentException("Arguments must be a primitive type (excluding void)!"); @@ -1968,6 +1968,7 @@ static boolean isWideningPrimitiveConversion(Class original, Class target) return false; } + private enum LanguageParserPrimitiveType { // Including "Enum" (or really, any differentiating string) in these names is important. They're used // in a switch() statement, which apparently does not support qualified names. And we can't use @@ -2498,6 +2499,7 @@ public Class visit(MethodCallExpr n, VisitArgs printer) { // Attempt python function call vectorization. 
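For illustration, a minimal sketch of the user-visible effect of the new `verifyPyCallableArguments` check, assuming a running Deephaven session where `f` is defined in the console (the function name is hypothetical); the error text follows the pattern asserted in the updated Python tests later in this patch.

```python
# Minimal sketch (assumes a Deephaven session; `f` is picked up from the query scope).
import numpy as np
from deephaven import empty_table, DHError


def f(x: np.int16) -> bool:  # np.int16 cannot hold the int column "X" without narrowing
    return bool(x)


try:
    empty_table(1).update(["X = i", "Y = f(X)"])
except DHError as e:
    # Raised while the formula is compiled, not at call time; the message follows the
    # pattern "f: Expected argument (x) to be one of [...], got int".
    print(e)
```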
if (scopeType != null && PyCallableWrapper.class.isAssignableFrom(scopeType)) { + verifyPyCallableArguments(n, argTypes); tryVectorizePythonCallable(n, scopeType, convertedArgExpressions, argTypes); } @@ -2505,6 +2507,29 @@ public Class visit(MethodCallExpr n, VisitArgs printer) { typeArguments); } + private void verifyPyCallableArguments(@NotNull MethodCallExpr n, @NotNull Class[] argTypes) { + final String invokedMethodName = n.getNameAsString(); + + if (GET_ATTRIBUTE_METHOD_NAME.equals(invokedMethodName)) { + // Currently Python UDF handling is only supported for top module level function(callable) calls. + // The getAttribute() calls which is needed to support Python method calls, which is beyond the scope of + // current implementation. So we are skipping the argument verification for getAttribute() calls. + return; + } + if (!n.containsData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS)) { + return; + } + final PyCallableDetails pyCallableDetails = n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS); + final String pyMethodName = pyCallableDetails.pythonMethodName; + final Object methodVar = queryScopeVariables.get(pyMethodName); + if (!(methodVar instanceof PyCallableWrapper)) { + return; + } + final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) methodVar; + pyCallableWrapper.parseSignature(); + pyCallableWrapper.verifyArguments(argTypes); + } + private Optional makeCastExpressionForPyCallable(Class retType, MethodCallExpr callMethodCall) { if (retType.isPrimitive()) { return Optional.of(new CastExpr( @@ -2552,7 +2577,7 @@ private Optional> pyCallableReturnType(@NotNull MethodCallExpr n) { } final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) paramValueRaw; pyCallableWrapper.parseSignature(); - return Optional.ofNullable(pyCallableWrapper.getReturnType()); + return Optional.ofNullable(pyCallableWrapper.getSignature().getReturnType()); } @NotNull @@ -2683,7 +2708,8 @@ private void checkVectorizability(@NotNull final MethodCallExpr n, pyCallableWrapper.parseSignature(); if (!pyCallableWrapper.isVectorizableReturnType()) { throw new PythonCallVectorizationFailure( - "Python function return type is not supported: " + pyCallableWrapper.getReturnType()); + "Python function return type is not supported: " + + pyCallableWrapper.getSignature().getReturnType()); } // Python vectorized functions(numba, DH) return arrays of primitive/Object types. This will break the generated @@ -2726,11 +2752,10 @@ private void checkVectorizability(@NotNull final MethodCallExpr n, } } - List> paramTypes = pyCallableWrapper.getParamTypes(); - if (paramTypes.size() != expressions.length) { + if (pyCallableWrapper.getSignature().getParameters().size() != expressions.length) { // note vectorization doesn't handle Python variadic arguments throw new PythonCallVectorizationFailure("Python function argument count mismatch: " + n + " " - + paramTypes.size() + " vs. " + expressions.length); + + pyCallableWrapper.getSignature().getParameters().size() + " vs. " + expressions.length); } } @@ -2739,10 +2764,9 @@ private void prepareVectorizationArgs( Expression[] expressions, Class[] argTypes, PyCallableWrapper pyCallableWrapper) { - List> paramTypes = pyCallableWrapper.getParamTypes(); - if (paramTypes.size() != expressions.length) { + if (pyCallableWrapper.getSignature().getParameters().size() != expressions.length) { throw new PythonCallVectorizationFailure("Python function argument count mismatch: " + n + " " - + paramTypes.size() + " vs. 
" + expressions.length); + + pyCallableWrapper.getSignature().getParameters().size() + " vs. " + expressions.length); } pyCallableWrapper.initializeChunkArguments(); @@ -2763,11 +2787,6 @@ private void prepareVectorizationArgs( } else { throw new IllegalStateException("Vectorizability check failed: " + n); } - - if (!isSafelyCoerceable(argTypes[i], paramTypes.get(i))) { - throw new PythonCallVectorizationFailure("Python vectorized function argument type mismatch: " + n + " " - + argTypes[i].getSimpleName() + " -> " + paramTypes.get(i).getSimpleName()); - } } } diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/select/AbstractConditionFilter.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/select/AbstractConditionFilter.java index 95098cdb5b4..481d25065ee 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/select/AbstractConditionFilter.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/select/AbstractConditionFilter.java @@ -267,7 +267,7 @@ private void checkAndInitializeVectorization(QueryLanguageParser.Result result, final PyCallableWrapperJpyImpl pyCallableWrapper = cws[0]; if (pyCallableWrapper.isVectorizable()) { - checkReturnType(result, pyCallableWrapper.getReturnType()); + checkReturnType(result, pyCallableWrapper.getSignature().getReturnType()); for (String variable : result.getVariablesUsed()) { if (variable.equals("i")) { @@ -284,7 +284,8 @@ private void checkAndInitializeVectorization(QueryLanguageParser.Result result, ArgumentsChunked argumentsChunked = pyCallableWrapper.buildArgumentsChunked(usedColumns); PyObject vectorized = pyCallableWrapper.vectorizedCallable(); DeephavenCompatibleFunction dcf = DeephavenCompatibleFunction.create(vectorized, - pyCallableWrapper.getReturnType(), usedColumns.toArray(new String[0]), argumentsChunked, true); + pyCallableWrapper.getSignature().getReturnType(), usedColumns.toArray(new String[0]), + argumentsChunked, true); setFilter(new ConditionFilter.ChunkFilter( dcf.toFilterKernel(), dcf.getColumnNames().toArray(new String[0]), diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/select/DhFormulaColumn.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/select/DhFormulaColumn.java index 6d788018888..e1babe4bec3 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/select/DhFormulaColumn.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/select/DhFormulaColumn.java @@ -238,7 +238,8 @@ private void checkAndInitializeVectorization(Map> co PyObject vectorized = pyCallableWrapper.vectorizedCallable(); formulaColumnPython = FormulaColumnPython.create(this.columnName, DeephavenCompatibleFunction.create(vectorized, - pyCallableWrapper.getReturnType(), this.analyzedFormula.sourceDescriptor.sources, + pyCallableWrapper.getSignature().getReturnType(), + this.analyzedFormula.sourceDescriptor.sources, argumentsChunked, true)); formulaColumnPython.initDef(columnDefinitionMap); diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java index b9c24773981..bcbcfb5462a 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java @@ -6,6 +6,7 @@ import org.jpy.PyObject; import java.util.List; +import java.util.Set; /** * Created by rbasralian on 8/12/23 @@ -19,8 +20,6 @@ public interface PyCallableWrapper { Object 
call(Object... args); - List> getParamTypes(); - boolean isVectorized(); boolean isVectorizable(); @@ -31,7 +30,46 @@ public interface PyCallableWrapper { void addChunkArgument(ChunkArgument chunkArgument); - Class getReturnType(); + Signature getSignature(); + + void verifyArguments(Class[] argTypes); + + class Parameter { + private final String name; + private final Set> possibleTypes; + + + public Parameter(String name, Set> possibleTypes) { + this.name = name; + this.possibleTypes = possibleTypes; + } + + public Set> getPossibleTypes() { + return possibleTypes; + } + + public String getName() { + return name; + } + } + + class Signature { + private final List parameters; + private final Class returnType; + + public Signature(List parameters, Class returnType) { + this.parameters = parameters; + this.returnType = returnType; + } + + public List getParameters() { + return parameters; + } + + public Class getReturnType() { + return returnType; + } + } abstract class ChunkArgument { private final Class type; @@ -88,4 +126,5 @@ public Object getValue() { } boolean isVectorizableReturnType(); + } diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java index 6ba399302cb..84c75748aed 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java @@ -10,12 +10,10 @@ import org.jpy.PyObject; import java.time.Instant; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; + +import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.NULL_CLASS; +import static io.deephaven.util.type.TypeUtils.getUnboxedType; /** * When given a pyObject that is a callable, we stick it inside the callable wrapper, which implements a call() varargs @@ -30,6 +28,7 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper { private static final PyModule dh_udf_module = PyModule.importModule("deephaven._udf"); private static final Map> numpyType2JavaClass = new HashMap<>(); + private static final Map> numpyType2JavaArrayClass = new HashMap<>(); static { numpyType2JavaClass.put('b', byte.class); @@ -43,6 +42,18 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper { numpyType2JavaClass.put('U', String.class); numpyType2JavaClass.put('M', Instant.class); numpyType2JavaClass.put('O', Object.class); + + numpyType2JavaArrayClass.put('b', byte[].class); + numpyType2JavaArrayClass.put('h', short[].class); + numpyType2JavaArrayClass.put('H', char[].class); + numpyType2JavaArrayClass.put('i', int[].class); + numpyType2JavaArrayClass.put('l', long[].class); + numpyType2JavaArrayClass.put('f', float[].class); + numpyType2JavaArrayClass.put('d', double[].class); + numpyType2JavaArrayClass.put('?', Boolean[].class); + numpyType2JavaArrayClass.put('U', String[].class); + numpyType2JavaArrayClass.put('M', Instant[].class); + numpyType2JavaArrayClass.put('O', Object[].class); } /** @@ -70,14 +81,12 @@ public static void init() {} @Override public boolean isVectorizableReturnType() { parseSignature(); - return vectorizableReturnTypes.contains(returnType); + return vectorizableReturnTypes.contains(signature.getReturnType()); } private final PyObject pyCallable; - - private String signature = null; - private List> paramTypes; - private Class 
returnType; + private String signatureString = null; + private Signature signature; private boolean vectorizable = false; private boolean vectorized = false; private Collection chunkArguments; @@ -110,7 +119,7 @@ public ArgumentsChunked buildArgumentsChunked(List columnNames) { ((ColumnChunkArgument) arg).setSourceChunkIndex(chunkSourceIndex); } } - return new ArgumentsChunked(chunkArguments, returnType, numbaVectorized); + return new ArgumentsChunked(chunkArguments, signature.getReturnType(), numbaVectorized); } /** @@ -168,12 +177,13 @@ private void prepareSignature() { vectorized = false; } pyUdfDecoratedCallable = dh_udf_module.call("_py_udf", unwrapped); - signature = pyUdfDecoratedCallable.getAttribute("signature").toString(); + signatureString = pyUdfDecoratedCallable.getAttribute("signature").toString(); } + @Override public void parseSignature() { - if (signature != null) { + if (signatureString != null) { return; } @@ -181,35 +191,110 @@ public void parseSignature() { // the 'types' field of a vectorized function follows the pattern of '[ilhfdb?O]*->[ilhfdb?O]', // eg. [ll->d] defines two int64 (long) arguments and a double return type. - if (signature == null || signature.isEmpty()) { + if (signatureString == null || signatureString.isEmpty()) { throw new IllegalStateException("Signature should always be available."); } - List> paramTypes = new ArrayList<>(); - for (char numpyTypeChar : signature.toCharArray()) { - if (numpyTypeChar != '-') { - Class paramType = numpyType2JavaClass.get(numpyTypeChar); - if (paramType == null) { - throw new IllegalStateException( - "Parameters of vectorized functions should always be of integral, floating point, boolean, String, or Object type: " - + numpyTypeChar + " of " + signature); + String pyEncodedParamsStr = signatureString.split("->")[0]; + List parameters = new ArrayList<>(); + if (!pyEncodedParamsStr.isEmpty()) { + String[] pyEncodedParams = pyEncodedParamsStr.split(","); + for (String pyEncodedParam : pyEncodedParams) { + String[] paramDetail = pyEncodedParam.split(":"); + String paramName = paramDetail[0]; + String paramTypeCodes = paramDetail[1]; + Set> possibleTypes = new HashSet<>(); + for (int ti = 0; ti < paramTypeCodes.length(); ti++) { + char typeCode = paramTypeCodes.charAt(ti); + if (typeCode == '[') { + // skip the array type code + ti++; + possibleTypes.add(numpyType2JavaArrayClass.get(paramTypeCodes.charAt(ti))); + } else if (typeCode == 'N') { + possibleTypes.add(NULL_CLASS); + } else { + possibleTypes.add(numpyType2JavaClass.get(typeCode)); + } } - paramTypes.add(paramType); - } else { - break; + parameters.add(new Parameter(paramName, possibleTypes)); } } - this.paramTypes = paramTypes; - - returnType = pyUdfDecoratedCallable.getAttribute("return_type", null); + Class returnType = pyUdfDecoratedCallable.getAttribute("return_type", null); if (returnType == null) { throw new IllegalStateException( "Python functions should always have an integral, floating point, boolean, String, arrays, or Object return type"); } if (returnType == boolean.class) { - this.returnType = Boolean.class; + returnType = Boolean.class; + } + + signature = new Signature(parameters, returnType); + + } + + private boolean isSafelyCastable(Set> types, Class type) { + for (Class t : types) { + if (t.isAssignableFrom(type)) { + return true; + } + if (t.isPrimitive() && type.isPrimitive() && isLosslessWideningPrimitiveConversion(type, t)) { + return true; + } + } + return false; + } + + public static boolean isLosslessWideningPrimitiveConversion(Class 
original, Class target) { + if (original == null || !original.isPrimitive() || target == null || !target.isPrimitive() + || original.equals(void.class) || target.equals(void.class)) { + throw new IllegalArgumentException("Arguments must be a primitive type (excluding void)!"); + } + + if (original.equals(target)) { + return true; + } + + if (original.equals(byte.class)) { + return target == short.class || target == int.class || target == long.class; + } else if (original.equals(short.class) || original.equals(char.class)) { // char is unsigned, so it's a + // lossless conversion to int + return target == int.class || target == long.class; + } else if (original.equals(int.class)) { + return target == long.class; + } else if (original.equals(float.class)) { + return target == double.class; + } + + return false; + } + + public void verifyArguments(Class[] argTypes) { + String callableName = pyCallable.getAttribute("__name__").toString(); + List parameters = signature.getParameters(); + + for (int i = 0; i < argTypes.length; i++) { + // if there are more arguments than parameters, we'll need to consider the last parameter as a varargs + // parameter. This is not ideal. We should consider a better way to handle this, i.e. a way to convey that + // the function is variadic. + Set> types = + parameters.get(Math.min(i, parameters.size() - 1)).getPossibleTypes(); + + // to prevent the unpacking of an array column when calling a Python function, we prefix the column accessor + // with a cast to generic Object type, until we can find a way to convey that info, we'll just skip the + // check for Object type input + if (argTypes[i] == Object.class) { + continue; + } + + Class t = getUnboxedType(argTypes[i]) == null ? argTypes[i] : getUnboxedType(argTypes[i]); + if (!types.contains(t) && !types.contains(Object.class) && !isSafelyCastable(types, t)) { + throw new IllegalArgumentException( + callableName + ": " + "Expected argument (" + parameters.get(i).getName() + ") to be one of " + + parameters.get(i).getPossibleTypes() + ", got " + + (argTypes[i].equals(NULL_CLASS) ? "null" : argTypes[i])); + } } } @@ -229,11 +314,6 @@ public Object call(Object... args) { return PythonScopeJpyImpl.convert(pyCallable.callMethod("__call__", args)); } - @Override - public List> getParamTypes() { - return paramTypes; - } - @Override public boolean isVectorized() { return vectorized; @@ -260,8 +340,8 @@ public void addChunkArgument(ChunkArgument chunkArgument) { } @Override - public Class getReturnType() { - return returnType; + public Signature getSignature() { + return signature; } } diff --git a/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java b/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java index 76ef69a0687..6943a880769 100644 --- a/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java +++ b/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java @@ -6,6 +6,7 @@ import io.deephaven.engine.util.PyCallableWrapper; import org.jpy.PyObject; +import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -44,7 +45,6 @@ public Object call(Object... 
args) { throw new UnsupportedOperationException(); } - @Override public List> getParamTypes() { return parameterTypes; } @@ -71,8 +71,13 @@ public void initializeChunkArguments() {} public void addChunkArgument(ChunkArgument ignored) {} @Override - public Class getReturnType() { - return Object.class; + public Signature getSignature() { + return new Signature(new ArrayList<>(), Void.class); + } + + @Override + public void verifyArguments(Class[] argTypes) { + } @Override diff --git a/py/server/deephaven/_udf.py b/py/server/deephaven/_udf.py index 2bfe947e3c5..fb8b77704b7 100644 --- a/py/server/deephaven/_udf.py +++ b/py/server/deephaven/_udf.py @@ -5,9 +5,13 @@ import inspect import re import sys +import typing from dataclasses import dataclass, field +from datetime import datetime from functools import wraps -from typing import Callable, List, Any, Union, Tuple, _GenericAlias, Set +from typing import Callable, List, Any, Union, Tuple, _GenericAlias, Set, Optional, Sequence + +import pandas as pd from deephaven._dep import soft_dependency @@ -17,9 +21,9 @@ import numpy as np from deephaven import DHError, dtypes -from deephaven.dtypes import _np_ndarray_component_type, _np_dtype_char, _NUMPY_INT_TYPE_CODES, \ - _NUMPY_FLOATING_TYPE_CODES, _component_np_dtype_char, _J_ARRAY_NP_TYPE_MAP, _PRIMITIVE_DTYPE_NULL_MAP, _scalar, \ - _BUILDABLE_ARRAY_DTYPE_MAP +from deephaven.dtypes import _NUMPY_INT_TYPE_CODES, _NUMPY_FLOATING_TYPE_CODES, _J_ARRAY_NP_TYPE_MAP, \ + _PRIMITIVE_DTYPE_NULL_MAP, _scalar, \ + _BUILDABLE_ARRAY_DTYPE_MAP, DType from deephaven.jcompat import _j_array_to_numpy_array from deephaven.time import to_np_datetime64 @@ -27,7 +31,6 @@ test_vectorization = False vectorized_count = 0 - _SUPPORTED_NP_TYPE_CODES = {"b", "h", "H", "i", "l", "f", "d", "?", "U", "M", "O"} @@ -64,7 +67,7 @@ def encoded(self) -> str: then the return type char. If a parameter or the return of the function is not annotated, the default 'O' - object type, will be used. """ - param_str = ",".join(["".join(p.encoded_types) for p in self.params]) + param_str = ",".join([str(p.name) + ":" + "".join(p.encoded_types) for p in self.params]) # ret_annotation has only one parsed annotation, and it might be Optional which means it contains 'N' in the # encoded type. We need to remove it. return_type_code = re.sub(r"[N]", "", self.ret_annotation.encoded_type) @@ -80,7 +83,7 @@ def _encode_param_type(t: type) -> str: return "N" # find the component type if it is numpy ndarray - component_type = _np_ndarray_component_type(t) + component_type = _component_np_dtype_char(t) if component_type: t = component_type @@ -92,6 +95,89 @@ def _encode_param_type(t: type) -> str: return tc +def _np_dtype_char(t: Union[type, str]) -> str: + """Returns the numpy dtype character code for the given type.""" + try: + np_dtype = np.dtype(t if t else "object") + if np_dtype.kind == "O": + if t in (datetime, pd.Timestamp): + return "M" + except TypeError: + np_dtype = np.dtype("object") + + return np_dtype.char + + +def _component_np_dtype_char(t: type) -> Optional[str]: + """Returns the numpy dtype character code for the given type's component type if the type is a Sequence type or + numpy ndarray, otherwise return None. 
""" + component_type = None + + if sys.version_info > (3, 8): + import types + if isinstance(t, types.GenericAlias) and issubclass(t.__origin__, Sequence): # novermin + component_type = t.__args__[0] + + if not component_type: + if isinstance(t, _GenericAlias) and issubclass(t.__origin__, Sequence): + component_type = t.__args__[0] + # if the component type is a DType, get its numpy type + if isinstance(component_type, DType): + component_type = component_type.np_type + + if not component_type: + if t == bytes or t == bytearray: + return "b" + + if not component_type: + component_type = _np_ndarray_component_type(t) + + if component_type: + return _np_dtype_char(component_type) + else: + return None + + +def _np_ndarray_component_type(t: type) -> Optional[type]: + """Returns the numpy ndarray component type if the type is a numpy ndarray, otherwise return None.""" + + # Py3.8: npt.NDArray can be used in Py 3.8 as a generic alias, but a specific alias (e.g. npt.NDArray[np.int64]) + # is an instance of a private class of np, yet we don't have a choice but to use it. And when npt.NDArray is used, + # the 1st argument is typing.Any, the 2nd argument is another generic alias of which the 1st argument is the + # component type + component_type = None + if sys.version_info.major == 3 and sys.version_info.minor == 8: + if isinstance(t, np._typing._generic_alias._GenericAlias) and t.__origin__ == np.ndarray: + component_type = t.__args__[1].__args__[0] + # Py3.9+, np.ndarray as a generic alias is only supported in Python 3.9+, also npt.NDArray is still available but a + # specific alias (e.g. npt.NDArray[np.int64]) now is an instance of typing.GenericAlias. + # when npt.NDArray is used, the 1st argument is typing.Any, the 2nd argument is another generic alias of which + # the 1st argument is the component type + # when np.ndarray is used, the 1st argument is the component type + if not component_type and sys.version_info.major == 3 and sys.version_info.minor > 8: + import types + if isinstance(t, types.GenericAlias) and t.__origin__ == np.ndarray: # novermin + nargs = len(t.__args__) + if nargs == 1: + component_type = t.__args__[0] + elif nargs == 2: # for npt.NDArray[np.int64], etc. 
+ a0 = t.__args__[0] + a1 = t.__args__[1] + if a0 == typing.Any and isinstance(a1, types.GenericAlias): # novermin + component_type = a1.__args__[0] + return component_type + + +def _is_union_type(t: type) -> bool: + """Return True if the type is a Union type""" + if sys.version_info.major == 3 and sys.version_info.minor >= 10: + import types + if isinstance(t, types.UnionType): # novermin + return True + + return isinstance(t, _GenericAlias) and t.__origin__ == Union + + def _parse_param(name: str, annotation: Any) -> _ParsedParam: """ Parse a parameter annotation in a function's signature """ p_param = _ParsedParam(name) @@ -99,7 +185,7 @@ def _parse_param(name: str, annotation: Any) -> _ParsedParam: if annotation is inspect._empty: p_param.encoded_types.add("O") p_param.none_allowed = True - elif isinstance(annotation, _GenericAlias) and annotation.__origin__ == Union: + elif _is_union_type(annotation): for t in annotation.__args__: _parse_type_no_nested(annotation, p_param, t) else: @@ -149,7 +235,7 @@ def _parse_return_annotation(annotation: Any) -> _ParsedReturnAnnotation: t = annotation pra.orig_type = t - if isinstance(annotation, _GenericAlias) and annotation.__origin__ == Union and len(annotation.__args__) == 2: + if _is_union_type(annotation) and len(annotation.__args__) == 2: # if the annotation is a Union of two types, we'll use the non-None type if annotation.__args__[1] == type(None): # noqa: E721 t = annotation.__args__[0] @@ -170,7 +256,8 @@ def _parse_return_annotation(annotation: Any) -> _ParsedReturnAnnotation: if numba: - def _parse_numba_signature(fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufunc.dufunc.DUFunc]) -> _ParsedSignature: + def _parse_numba_signature( + fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufunc.dufunc.DUFunc]) -> _ParsedSignature: """ Parse a numba function's signature""" sigs = fn.types # in the format of ll->l, ff->f,dd->d,OO->O, etc. if sigs: @@ -261,7 +348,8 @@ def _parse_signature(fn: Callable) -> _ParsedSignature: t = eval(p.annotation, fn.__globals__) if isinstance(p.annotation, str) else p.annotation p_sig.params.append(_parse_param(n, t)) - t = eval(sig.return_annotation, fn.__globals__) if isinstance(sig.return_annotation, str) else sig.return_annotation + t = eval(sig.return_annotation, fn.__globals__) if isinstance(sig.return_annotation, + str) else sig.return_annotation p_sig.ret_annotation = _parse_return_annotation(t) return p_sig @@ -389,7 +477,6 @@ def _py_udf(fn: Callable): # build a signature string for vectorization by removing NoneType, array char '[', and comma from the encoded types # since vectorization only supports UDFs with a single signature and enforces an exact match, any non-compliant # signature (e.g. Union with more than 1 non-NoneType) will be rejected by the vectorizer. 
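For illustration, a small sketch of the name-annotated encoding that `wrapper.signature` now carries (parameter names, `N` for None-allowed, `[` for array component codes), using the private helper `_parse_signature` from this module purely for inspection; it assumes a Deephaven Python environment, and the int64 code (`l`) is platform-dependent.

```python
# Sketch of the name-annotated signature encoding (private helpers, for illustration only).
from typing import Optional

import numpy as np
import numpy.typing as npt

from deephaven._udf import _parse_signature


def f(x: np.int32, y: Optional[np.float64], z: npt.NDArray[np.int64]) -> bool:
    return True


sig = _parse_signature(f)
# Each parameter is encoded as "name:typecodes"; 'N' marks a None-allowed parameter,
# '[' prefixes an array component code, and the return code follows '->'.
# Expected to print something like "x:i,y:dN,z:[l->?" (the codes for y come from a set,
# so they may also appear as "Nd").
print(sig.encoded)
```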
- sig_str_vectorization = re.sub(r"[\[N,]", "", p_sig.encoded) return_array = p_sig.ret_annotation.has_array ret_dtype = dtypes.from_np_dtype(np.dtype(p_sig.ret_annotation.encoded_type[-1])) @@ -414,7 +501,7 @@ def wrapper(*args, **kwargs): j_class = real_ret_dtype.qst_type.clazz() wrapper.return_type = j_class - wrapper.signature = sig_str_vectorization + wrapper.signature = p_sig.encoded return wrapper @@ -475,4 +562,4 @@ def wrapper(*args): global vectorized_count vectorized_count += 1 - return wrapper \ No newline at end of file + return wrapper diff --git a/py/server/deephaven/dtypes.py b/py/server/deephaven/dtypes.py index 7da00a0fe37..5aa6c8acca8 100644 --- a/py/server/deephaven/dtypes.py +++ b/py/server/deephaven/dtypes.py @@ -9,13 +9,10 @@ from __future__ import annotations import datetime -import sys -import typing -from typing import Any, Sequence, Callable, Dict, Type, Union, _GenericAlias, Optional +from typing import Any, Sequence, Callable, Dict, Type, Union, Optional import jpy import numpy as np -import numpy._typing as npt import pandas as pd from deephaven import DHError @@ -304,7 +301,7 @@ def array(dtype: DType, seq: Optional[Sequence], remap: Callable[[Any], Any] = N raise DHError(e, f"failed to create a Java {dtype.j_name} array.") from e -def from_jtype(j_class: Any) -> DType: +def from_jtype(j_class: Any) -> Optional[DType]: """ looks up a DType that matches the java type, if not found, creates a DType for it. """ if not j_class: return None @@ -391,65 +388,3 @@ def _scalar(x: Any, dtype: DType) -> Any: return x except: return x - - -def _np_dtype_char(t: Union[type, str]) -> str: - """Returns the numpy dtype character code for the given type.""" - try: - np_dtype = np.dtype(t if t else "object") - if np_dtype.kind == "O": - if t in (datetime.datetime, pd.Timestamp): - return "M" - except TypeError: - np_dtype = np.dtype("object") - - return np_dtype.char - - -def _component_np_dtype_char(t: type) -> Optional[str]: - """Returns the numpy dtype character code for the given type's component type if the type is a Sequence type or - numpy ndarray, otherwise return None. """ - component_type = None - if isinstance(t, _GenericAlias) and issubclass(t.__origin__, Sequence): - component_type = t.__args__[0] - # if the component type is a DType, get its numpy type - if isinstance(component_type, DType): - component_type = component_type.np_type - - if not component_type: - component_type = _np_ndarray_component_type(t) - - if component_type: - return _np_dtype_char(component_type) - else: - return None - - -def _np_ndarray_component_type(t: type) -> Optional[type]: - """Returns the numpy ndarray component type if the type is a numpy ndarray, otherwise return None.""" - - # Py3.8: npt.NDArray can be used in Py 3.8 as a generic alias, but a specific alias (e.g. npt.NDArray[np.int64]) - # is an instance of a private class of np, yet we don't have a choice but to use it. And when npt.NDArray is used, - # the 1st argument is typing.Any, the 2nd argument is another generic alias of which the 1st argument is the - # component type - component_type = None - if sys.version_info.major == 3 and sys.version_info.minor == 8: - if isinstance(t, np._typing._generic_alias._GenericAlias) and t.__origin__ == np.ndarray: - component_type = t.__args__[1].__args__[0] - # Py3.9+, np.ndarray as a generic alias is only supported in Python 3.9+, also npt.NDArray is still available but a - # specific alias (e.g. npt.NDArray[np.int64]) now is an instance of typing.GenericAlias. 
- # when npt.NDArray is used, the 1st argument is typing.Any, the 2nd argument is another generic alias of which - # the 1st argument is the component type - # when np.ndarray is used, the 1st argument is the component type - if not component_type and sys.version_info.major == 3 and sys.version_info.minor > 8: - import types - if isinstance(t, types.GenericAlias) and (issubclass(t.__origin__, Sequence) or t.__origin__ == np.ndarray): # novermin - nargs = len(t.__args__) - if nargs == 1: - component_type = t.__args__[0] - elif nargs == 2: # for npt.NDArray[np.int64], etc. - a0 = t.__args__[0] - a1 = t.__args__[1] - if a0 == typing.Any and isinstance(a1, types.GenericAlias): - component_type = a1.__args__[0] - return component_type diff --git a/py/server/tests/test_numba_guvectorize.py b/py/server/tests/test_numba_guvectorize.py index 096d88902f3..2fa6c42b0e9 100644 --- a/py/server/tests/test_numba_guvectorize.py +++ b/py/server/tests/test_numba_guvectorize.py @@ -66,7 +66,10 @@ def g(x, dummy, res): res[0] = min(x) res[1] = max(x) - t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y,dummy)") + # convert dummy to a Java array + # TODO this is a hack, we might want to add a helper function for QLP to call to get the type of a PyObject arg + j_array = dtypes.array(dtypes.int64, dummy) + t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y,j_array)") self.assertEqual(t.columns[2].data_type, dtypes.long_array) def test_np_on_java_array(self): @@ -78,7 +81,11 @@ def g(x, dummy, res): res[0] = np.min(x) res[1] = np.max(x) - t = empty_table(10).update(["X=i%3", "Y=ii"]).group_by("X").update("Z=g(Y,dummy)") + # convert dummy to a Java array + # TODO this is a hack, we might want to add a helper function for QLP to call to get the type of a PyObject arg + j_array = dtypes.array(dtypes.int64, dummy) + + t = empty_table(10).update(["X=i%3", "Y=ii"]).group_by("X").update("Z=g(Y,j_array)") self.assertEqual(t.columns[2].data_type, dtypes.long_array) def test_np_on_java_array2(self): @@ -99,7 +106,7 @@ def g(x, res): with self.assertRaises(DHError) as cm: t = empty_table(10).update(["X=i%3", "Y=(double)ii"]).group_by("X").update("Z=g(Y)") - self.assertIn("Argument 1", str(cm.exception)) + self.assertIn("g: Expected argument (1)", str(cm.exception)) if __name__ == '__main__': diff --git a/py/server/tests/test_udf_numpy_args.py b/py/server/tests/test_udf_args.py similarity index 79% rename from py/server/tests/test_udf_numpy_args.py rename to py/server/tests/test_udf_args.py index f77bc615def..0f8f875ffa8 100644 --- a/py/server/tests/test_udf_numpy_args.py +++ b/py/server/tests/test_udf_args.py @@ -2,13 +2,14 @@ # Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending # import typing -from typing import Optional, Union, Any +from typing import Optional, Union, Any, Sequence import unittest import numpy as np import numpy.typing as npt -from deephaven import empty_table, DHError, dtypes +from deephaven import empty_table, DHError, dtypes, new_table +from deephaven.column import int_col from deephaven.dtypes import double_array, int32_array, long_array, int16_array, char_array, int8_array, \ float32_array from tests.testbase import BaseTestCase @@ -218,7 +219,7 @@ def f11(p1: Union[float, np.float32]) -> bool: with self.assertRaises(DHError) as cm: t = empty_table(10).update(["X1 = f11(i)"]) - def f2(p1: Union[np.int16, np.float64]) -> Union[Optional[bool]]: + def f2(p1: Union[np.int32, np.float64]) -> Union[Optional[bool]]: return bool(p1) t = 
empty_table(10).update(["X1 = f2(i)"]) @@ -231,7 +232,7 @@ def f21(p1: Union[np.int16, np.float64]) -> Union[Optional[bool], int]: with self.assertRaises(DHError) as cm: t = empty_table(10).update(["X1 = f21(i)"]) - def f3(p1: Union[np.int16, np.float64], p2=None) -> bool: + def f3(p1: Union[np.int32, np.float64], p2=None) -> bool: return bool(p1) t = empty_table(10).update(["X1 = f3(i)"]) @@ -244,7 +245,7 @@ def f4(p1: Union[np.int16, np.float64], p2=None) -> bool: self.assertEqual(t.columns[0].data_type, dtypes.bool_) with self.assertRaises(DHError) as cm: t = empty_table(10).update(["X1 = f4(now())"]) - self.assertRegex(str(cm.exception), "Argument .* is not compatible with annotation*") + self.assertRegex(str(cm.exception), "f4: Expected .* got .*Instant") def f41(p1: Union[np.int16, np.float64, Union[Any]], p2=None) -> bool: return bool(p1) @@ -266,7 +267,7 @@ def f5(col1, col2: np.ndarray[np.int32]) -> bool: t = t.update(["X1 = f5(X, Y)"]) with self.assertRaises(DHError) as cm: t = t.update(["X1 = f5(X, null)"]) - self.assertRegex(str(cm.exception), "Argument .* is not compatible with annotation*") + self.assertRegex(str(cm.exception), "f5: Expected .* got null") def f51(col1, col2: Optional[np.ndarray[np.int32]]) -> bool: return np.nanmean(col2) == np.mean(col2) @@ -287,7 +288,7 @@ def f6(*args: np.int32, col2: np.ndarray[np.int32]) -> bool: with self.assertRaises(DHError) as cm: t1 = t.update(["X1 = f6(X, Y=null)"]) - self.assertIn("not compatible with annotation", str(cm.exception)) + self.assertIn("f6: Expected argument (col2) to be one of [class [I], got boolean", str(cm.exception)) def test_str_bool_datetime_array(self): with self.subTest("str"): @@ -299,7 +300,7 @@ def f1(p1: np.ndarray[str], p2=None) -> bool: self.assertEqual(t1.columns[2].data_type, dtypes.bool_) with self.assertRaises(DHError) as cm: t2 = t.update(["X1 = f1(null, Y )"]) - self.assertRegex(str(cm.exception), "Argument .* is not compatible with annotation*") + self.assertRegex(str(cm.exception), "f1: Expected .* got null") def f11(p1: Union[np.ndarray[str], None], p2=None) -> bool: return bool(len(p1)) if p1 is not None else False @@ -315,7 +316,7 @@ def f2(p1: np.ndarray[np.datetime64], p2=None) -> bool: self.assertEqual(t1.columns[2].data_type, dtypes.bool_) with self.assertRaises(DHError) as cm: t2 = t.update(["X1 = f2(null, Y )"]) - self.assertRegex(str(cm.exception), "Argument .* is not compatible with annotation*") + self.assertRegex(str(cm.exception), "f2: Expected .* got null") def f21(p1: Union[np.ndarray[np.datetime64], None], p2=None) -> bool: return bool(len(p1)) if p1 is not None else False @@ -337,7 +338,7 @@ def f3(p1: np.ndarray[np.bool_], p2=None) -> bool: self.assertEqual(t1.columns[2].data_type, dtypes.bool_) with self.assertRaises(DHError) as cm: t2 = t.update(["X1 = f3(null, Y )"]) - self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation") + self.assertRegex(str(cm.exception), "f3: Expected .* got null") def f31(p1: Optional[np.ndarray[bool]], p2=None) -> bool: return bool(len(p1)) if p1 is not None else False @@ -405,9 +406,9 @@ def f(x: {p_type}) -> bool: # note typing t = empty_table(1).update(["X = i", f"Y = f(({p_type})X)"]) self.assertEqual(1, t.to_string(cols="Y").count("true")) - - np_int_types = {"np.int8", "np.int16", "np.int32", "np.int64"} - for p_type in np_int_types: + def test_np_typehints(self): + widening_np_int_types = {"np.int32", "np.int64"} + for p_type in widening_np_int_types: with self.subTest(p_type): func_str = f""" def f(x: 
{p_type}) -> bool: # note typing @@ -417,8 +418,20 @@ def f(x: {p_type}) -> bool: # note typing t = empty_table(1).update(["X = i", f"Y = f(X)"]) self.assertEqual(1, t.to_string(cols="Y").count("true")) - np_floating_types = {"np.float32", "np.float64"} - for p_type in np_floating_types: + narrowing_np_int_types = {"np.int8", "np.int16"} + for p_type in narrowing_np_int_types: + with self.subTest(p_type): + func_str = f""" +def f(x: {p_type}) -> bool: # note typing + return type(x) == {p_type} +""" + exec(func_str, globals()) + with self.assertRaises(DHError) as cm: + t = empty_table(1).update(["X = i", f"Y = f(X)"]) + self.assertRegex(str(cm.exception), "f: Expect") + + widening_np_floating_types = {"np.float32", "np.float64"} + for p_type in widening_np_floating_types: with self.subTest(p_type): func_str = f""" def f(x: {p_type}) -> bool: # note typing @@ -428,5 +441,79 @@ def f(x: {p_type}) -> bool: # note typing t = empty_table(1).update(["X = i", f"Y = f((float)X)"]) self.assertEqual(1, t.to_string(cols="Y").count("true")) + int_to_floating_types = {"np.float32", "np.float64"} + for p_type in int_to_floating_types: + with self.subTest(p_type): + func_str = f""" +def f(x: {p_type}) -> bool: # note typing + return type(x) == {p_type} +""" + exec(func_str, globals()) + with self.assertRaises(DHError) as cm: + t = empty_table(1).update(["X = i", f"Y = f(X)"]) + self.assertRegex(str(cm.exception), "f: Expect") + + def test_sequence_args(self): + with self.subTest("Sequence"): + def f(x: Sequence[int]) -> bool: + return True + + with self.assertRaises(DHError) as cm: + t = empty_table(1).update(["X = i", "Y = f(ii)"]) + self.assertRegex(str(cm.exception), "f: Expect") + + t = empty_table(1).update(["X = i", "Y = ii"]).group_by("X").update(["Z = f(Y.toArray())"]) + self.assertEqual(t.columns[2].data_type, dtypes.bool_) + + with self.subTest("bytes"): + def f(x: bytes) -> bool: + return True + + with self.assertRaises(DHError) as cm: + t = empty_table(1).update(["X = i", "Y = f(ii)"]) + self.assertRegex(str(cm.exception), "f: Expect") + + t = empty_table(1).update(["X = i", "Y = (byte)(ii % 128)"]).group_by("X").update(["Z = f(Y.toArray())"]) + self.assertEqual(t.columns[2].data_type, dtypes.bool_) + + with self.subTest("bytearray"): + def f(x: bytearray) -> bool: + return True + + with self.assertRaises(DHError) as cm: + t = empty_table(1).update(["X = i", "Y = f(ii)"]) + self.assertRegex(str(cm.exception), "f: Expect") + + t = empty_table(1).update(["X = i", "Y = (byte)(ii % 128)"]).group_by("X").update(["Z = f(Y.toArray())"]) + self.assertEqual(t.columns[2].data_type, dtypes.bool_) + + def test_non_common_cases(self): + def f1(x: int) -> float: + ... + + def f2(x: float) -> int: + ... 
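For reference, the acceptances and rejections asserted in the type-hint tests above follow the widening table added in `PyCallableWrapperJpyImpl.isLosslessWideningPrimitiveConversion`; below is a Python rendering of that table, for illustration only (it is not engine code).

```python
# Reference-only mirror of the Java widening table added in this patch; not engine code.
_LOSSLESS_WIDENINGS = {
    "byte": {"short", "int", "long"},
    "short": {"int", "long"},
    "char": {"int", "long"},  # char is unsigned, so int/long hold it losslessly
    "int": {"long"},
    "float": {"double"},
}


def is_lossless_widening(original: str, target: str) -> bool:
    return original == target or target in _LOSSLESS_WIDENINGS.get(original, set())


assert is_lossless_widening("int", "long")        # why f(x: np.int64) accepts an int column
assert not is_lossless_widening("int", "short")   # why f(x: np.int16) is rejected above
assert not is_lossless_widening("int", "double")  # int -> floating point is not in the table
```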
+ + t = empty_table(1).update("X = f2(f1(ii))") + self.assertEqual(t.columns[0].data_type, dtypes.int_) + + def test_varargs(self): + cols = ["A", "B", "C", "D"] + + def my_sum(p1: np.int32, *args: np.int64) -> int: + return sum(args) + + t = new_table([int_col(c, [0, 1, 2, 3, 4, 5, 6]) for c in cols]) + result = t.update(f"X = my_sum({','.join(cols)})") + self.assertEqual(result.columns[4].data_type, dtypes.int64) + + def my_sum_error(p1: np.int32, *args: np.int16) -> int: + return sum(args) + with self.assertRaises(DHError) as cm: + t.update(f"X = my_sum_error({','.join(cols)})") + self.assertRegex(str(cm.exception), "my_sum_error: Expected argument .* got int") + + + if __name__ == "__main__": unittest.main() diff --git a/py/server/tests/test_udf_return_java_values.py b/py/server/tests/test_udf_return_java_values.py index 7c1f55c5827..6d897195e21 100644 --- a/py/server/tests/test_udf_return_java_values.py +++ b/py/server/tests/test_udf_return_java_values.py @@ -67,6 +67,15 @@ def test_array_return(self): t = empty_table(10).update(["X = i % 3", "Y = i"]).group_by("X").update(f"Z= fn(Y + 1)") self.assertEqual(t.columns[2].data_type, dh_dtype) + container_types = ["bytes", "bytearray"] + for container_type in container_types: + with self.subTest(container_type=container_type): + func_decl_str = f"""def fn(col) -> {container_type}:""" + func_body_str = f""" return {container_type}(col)""" + exec("\n".join([func_decl_str, func_body_str]), globals()) + t = empty_table(10).update(["X = i % 3", "Y = i"]).group_by("X").update(f"Z= fn(Y + 1)") + self.assertEqual(t.columns[2].data_type, dtypes.byte_array) + def test_scalar_return_class_method_not_supported(self): for dh_dtype, np_dtype in _J_TYPE_NP_DTYPE_MAP.items(): with self.subTest(dh_dtype=dh_dtype, np_dtype=np_dtype): @@ -287,7 +296,7 @@ def test_np_ufunc(self): nbsin = numba.vectorize([numba.float64(numba.float64)])(np.sin) # this is the workaround that utilizes vectorization and type inference - @numba.vectorize([numba.float64(numba.float64)], nopython=True) + @numba.vectorize([numba.float64(numba.int64)], nopython=True) def nbsin(x): return np.sin(x) t3 = empty_table(10).update(["X3 = nbsin(i)"]) diff --git a/py/server/tests/test_vectorization.py b/py/server/tests/test_vectorization.py index d7532647640..627ca44eb94 100644 --- a/py/server/tests/test_vectorization.py +++ b/py/server/tests/test_vectorization.py @@ -15,7 +15,7 @@ from deephaven._udf import _dh_vectorize as dh_vectorize from tests.testbase import BaseTestCase -from tests.test_udf_numpy_args import _J_TYPE_NULL_MAP, _J_TYPE_NP_DTYPE_MAP, _J_TYPE_J_ARRAY_TYPE_MAP +from tests.test_udf_args import _J_TYPE_NULL_MAP, _J_TYPE_NP_DTYPE_MAP, _J_TYPE_J_ARRAY_TYPE_MAP class VectorizationTestCase(BaseTestCase): @@ -253,6 +253,7 @@ def my_sum(*args): source = new_table([int_col(c, [0, 1, 2, 3, 4, 5, 6]) for c in cols]) result = source.update(f"X = my_sum({','.join(cols)})") self.assertEqual(len(cols) + 1, len(result.columns)) + self.assertEqual(_udf.vectorized_count, 0) def test_enclosed_by_parentheses(self): def sinc(x) -> np.double: @@ -269,7 +270,7 @@ def sinc2(x): self.assertEqual(t.columns[1].data_type, dtypes.PyObject) def test_optional_annotations(self): - def pyfunc(p1: np.int32, p2: np.int32, p3: Optional[np.int32]) -> Optional[int]: + def pyfunc(p1: np.int32, p2: np.int64, p3: Optional[np.int32]) -> Optional[int]: total = p1 + p2 + p3 return None if total % 3 == 0 else total
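For illustration, a sketch of how the `N` code produced for `Optional`/None-allowed parameters interacts with the new check when a query passes the literal `null`; it assumes a running Deephaven session, and the names `strict` and `lenient` are hypothetical.

```python
# Sketch (assumes a Deephaven session): Optional annotations permit the query-language null.
from typing import Optional

import numpy as np
from deephaven import empty_table, DHError


def strict(x: np.int32) -> bool:
    return x is not None


def lenient(x: Optional[np.int32]) -> bool:
    return x is not None


t = empty_table(3).update(["X = i"])
t_ok = t.update(["Y = lenient(null)"])  # 'N' in the encoded signature admits null
try:
    t.update(["Y = strict(null)"])      # rejected before evaluation
except DHError as e:
    # Expected to mention: strict: Expected argument (x) to be one of [...], got null
    print(e)
```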