From 9645e8bf97b1332b98af7f546f2d264ecd9c1a2f Mon Sep 17 00:00:00 2001
From: jianfengmao
Date: Mon, 20 Nov 2023 16:50:54 -0700
Subject: [PATCH] A bit of a milestone

---
 py/server/deephaven/dtypes.py  |  49 +++++-
 py/server/deephaven/jcompat.py | 128 ++++++++++++++-
 py/server/deephaven/numpy.py   |  32 +---
 py/server/deephaven/pandas.py  |  49 ++----
 py/server/deephaven/table.py   | 286 +++++++++++++++++++++++++--------
 5 files changed, 400 insertions(+), 144 deletions(-)

diff --git a/py/server/deephaven/dtypes.py b/py/server/deephaven/dtypes.py
index 5f5857ffdbe..010ed157d42 100644
--- a/py/server/deephaven/dtypes.py
+++ b/py/server/deephaven/dtypes.py
@@ -188,6 +188,18 @@ def __call__(self, *args, **kwargs):
 }
 
 
+_J_ARRAY_NP_TYPE_MAP = {
+    boolean_array.j_type: np.dtype("?"),
+    byte_array.j_type: np.dtype("b"),
+    char_array.j_type: np.dtype("uint16"),
+    short_array.j_type: np.dtype("h"),
+    int32_array.j_type: np.dtype("i"),
+    long_array.j_type: np.dtype("l"),
+    float32_array.j_type: np.dtype("f"),
+    double_array.j_type: np.dtype("d")
+}
+
+
 def null_remap(dtype: DType) -> Callable[[Any], Any]:
     """ Creates a null value remap function for the provided DType.
 
@@ -329,6 +341,17 @@ def from_np_dtype(np_dtype: Union[np.dtype, pd.api.extensions.ExtensionDtype]) -
 _NUMPY_FLOATING_TYPE_CODES = ["f", "d"]
 
 
+def _is_py_null(x: Any) -> bool:
+    """Checks if the value is a Python null value, i.e. None, NaN, or pd.NA."""
+    if x is None:
+        return True
+
+    try:
+        return pd.isna(x)
+    except ValueError:
+        return False
+
+
 def _scalar(x: Any, dtype: DType) -> Any:
     """Converts a Python value to a Java scalar value. It converts numpy primitive types and strings to their
     Python equivalents so that JPY can handle them. For datetime values, it converts them to Java Instant.
     """
 
     # NULL_BOOL will appear in Java as a byte value which causes a cast error. We just let JPY convert it to Java null
     # and the engine has casting logic to handle it.
-    if x is None and dtype != bool_ and _PRIMITIVE_DTYPE_NULL_MAP.get(dtype):
+    if x is None and dtype not in (bool_, char) and _PRIMITIVE_DTYPE_NULL_MAP.get(dtype):
         return _PRIMITIVE_DTYPE_NULL_MAP[dtype]
 
     try:
@@ -354,6 +377,8 @@ def _scalar(x: Any, dtype: DType) -> Any:
         elif x.dtype.char == 'M':
             from deephaven.time import to_j_instant
             return to_j_instant(x)
+        elif x.dtype.char == 'H':  # np.uint16
+            return jpy.get_type("java.lang.Character")(int(x))
         elif isinstance(x, (datetime.datetime, pd.Timestamp)):
             from deephaven.time import to_j_instant
             return to_j_instant(x)
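The null handling above leans on Deephaven's sentinel convention: each primitive column type reserves one bit pattern to mean null, and `_scalar()` maps a Python `None` to that sentinel. A minimal standalone sketch of the idea — the constant below is an assumption mirroring `deephaven.constants.NULL_LONG` (Java's `Long.MIN_VALUE`); the real lookup is `_PRIMITIVE_DTYPE_NULL_MAP`:

```python
NULL_LONG = -(2 ** 63)  # assumed to mirror deephaven.constants.NULL_LONG (Long.MIN_VALUE)

def to_j_scalar(x, null_sentinel=NULL_LONG):
    """Sketch of _scalar()'s null path: a Python None becomes the sentinel value."""
    return null_sentinel if x is None else x

assert to_j_scalar(None) == -9223372036854775808
assert to_j_scalar(42) == 42
```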
@@ -382,14 +407,26 @@ def _component_np_dtype_char(t: type) -> Optional[str]:
     if isinstance(t, _GenericAlias) and issubclass(t.__origin__, Sequence):
         component_type = t.__args__[0]
 
+    if not component_type:
+        component_type = _np_ndarray_component_type(t)
+
+    if component_type:
+        return _np_dtype_char(component_type)
+    else:
+        return None
+
+
+def _np_ndarray_component_type(t: type) -> Optional[type]:
+    """Returns the numpy ndarray component type if the type is a numpy ndarray; otherwise returns None."""
+
     # Py3.8: npt.NDArray can be used in Py 3.8 as a generic alias, but a specific alias (e.g. npt.NDArray[np.int64])
     # is an instance of a private class of np, yet we don't have a choice but to use it. And when npt.NDArray is used,
     # the 1st argument is typing.Any, the 2nd argument is another generic alias of which the 1st argument is the
     # component type
-    if not component_type and sys.version_info.minor == 8:
+    component_type = None
+    if sys.version_info.minor == 8:
         if isinstance(t, np._typing._generic_alias._GenericAlias) and t.__origin__ == np.ndarray:
             component_type = t.__args__[1].__args__[0]
-
     # Py3.9+, np.ndarray as a generic alias is only supported in Python 3.9+, also npt.NDArray is still available but a
     # specific alias (e.g. npt.NDArray[np.int64]) now is an instance of typing.GenericAlias.
     # when npt.NDArray is used, the 1st argument is typing.Any, the 2nd argument is another generic alias of which
@@ -406,8 +443,4 @@ def _component_np_dtype_char(t: type) -> Optional[str]:
         a1 = t.__args__[1]
         if a0 == typing.Any and isinstance(a1, types.GenericAlias):
             component_type = a1.__args__[0]
-
-    if component_type:
-        return _np_dtype_char(component_type)
-    else:
-        return None
+    return component_type
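The generic-alias unwrapping can be demonstrated with plain typing introspection. A sketch of the Py3.9+ path only (the exact alias classes vary across Python/numpy versions, which is why the patch keeps a separate 3.8 branch):

```python
import types
import typing

import numpy as np
import numpy.typing as npt

def component_of(t):
    # npt.NDArray[np.int64] == np.ndarray[Any, np.dtype[np.int64]]; the component
    # type hides inside the second argument, itself a generic alias of np.dtype.
    if isinstance(t, types.GenericAlias) and t.__origin__ == np.ndarray:
        a0, a1 = t.__args__
        if a0 == typing.Any and isinstance(a1, types.GenericAlias):
            return a1.__args__[0]
    return None

print(component_of(npt.NDArray[np.int64]))  # <class 'numpy.int64'>
print(component_of(int))                    # None
```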
diff --git a/py/server/deephaven/jcompat.py b/py/server/deephaven/jcompat.py
index d12f0d01f64..1a9c923bf5c 100644
--- a/py/server/deephaven/jcompat.py
+++ b/py/server/deephaven/jcompat.py
@@ -5,12 +5,29 @@
 """ This module provides Java compatibility support including convenience functions to create some widely used Java
 data structures from corresponding Python ones in order to be able to call Java methods. """
 
-from typing import Any, Callable, Dict, Iterable, List, Sequence, Set, TypeVar, Union
+from typing import Any, Callable, Dict, Iterable, List, Sequence, Set, TypeVar, Union, Tuple, Literal
 
 import jpy
+import numpy as np
+import pandas as pd
 
+from deephaven import dtypes, DHError
 from deephaven._wrapper import unwrap, wrap_j_object
-from deephaven.dtypes import DType
+from deephaven.dtypes import DType, _PRIMITIVE_DTYPE_NULL_MAP, _J_ARRAY_NP_TYPE_MAP
+
+_NULL_BOOLEAN_AS_BYTE = jpy.get_type("io.deephaven.util.BooleanUtils").NULL_BOOLEAN_AS_BYTE
+_JPrimitiveArrayConversionUtility = jpy.get_type("io.deephaven.integrations.common.PrimitiveArrayConversionUtility")
+
+_DH_PANDAS_NULLABLE_TYPE_MAP: Dict[DType, pd.api.extensions.ExtensionDtype] = {
+    dtypes.bool_: pd.BooleanDtype,
+    dtypes.byte: pd.Int8Dtype,
+    dtypes.short: pd.Int16Dtype,
+    dtypes.char: pd.UInt16Dtype,
+    dtypes.int32: pd.Int32Dtype,
+    dtypes.int64: pd.Int64Dtype,
+    dtypes.float32: pd.Float32Dtype,
+    dtypes.float64: pd.Float64Dtype,
+}
 
 
 def is_java_type(obj: Any) -> bool:
@@ -181,11 +198,114 @@ def to_sequence(v: Union[T, Sequence[T]] = None, wrapped: bool = False) -> Seque
         return ()
     if wrapped:
         if not isinstance(v, Sequence) or isinstance(v, str):
-            return (v, )
+            return (v,)
         else:
             return tuple(v)
 
     if not isinstance(v, Sequence) or isinstance(v, str):
-        return (unwrap(v), )
+        return (unwrap(v),)
     else:
         return tuple((unwrap(o) for o in v))
+
+
+def _j_array_to_numpy_array(dtype: DType, j_array: jpy.JType, conv_null: bool = False, no_promotion: bool = False) -> \
+        np.ndarray:
+    """ Produces a numpy array from the DType and given Java array."""
+    if dtype.is_primitive:
+        np_array = np.frombuffer(j_array, dtype.np_type)
+    elif dtype == dtypes.Instant:
+        longs = _JPrimitiveArrayConversionUtility.translateArrayInstantToLong(j_array)
+        np_long_array = np.frombuffer(longs, np.int64)
+        np_array = np_long_array.view(dtype.np_type)
+    elif dtype == dtypes.bool_:
+        bytes_ = _JPrimitiveArrayConversionUtility.translateArrayBooleanToByte(j_array)
+        np_array = np.frombuffer(bytes_, dtype.np_type)
+    elif dtype == dtypes.string:
+        np_array = np.array([s for s in j_array], dtypes.string.np_type)
+    elif dtype.np_type is not np.object_:
+        try:
+            np_array = np.frombuffer(j_array, dtype.np_type)
+        except Exception:
+            np_array = np.array(j_array, np.object_)
+    else:
+        np_array = np.array(j_array, np.object_)
+
+    if conv_null:
+        dh_null = _PRIMITIVE_DTYPE_NULL_MAP.get(dtype)
+        if dh_null:
+            if dtype in (dtypes.float32, dtypes.float64):
+                np_array = np.copy(np_array)
+                np_array[np_array == dh_null] = np.nan
+            else:
+                if no_promotion:
+                    raise DHError(f"Java array contains Deephaven nulls for dtype {dtype}")
+                if dtype is dtypes.bool_:  # view the byte-encoded booleans so nulls can be detected
+                    np_array = np.frombuffer(np_array, np.byte)
+                if any(np_array[np_array == dh_null]):
+                    raise DHError(f"Java array contains Deephaven nulls for dtype {dtype}")
+
+    return np_array
+
+
+def _j_array_to_series(dtype: DType, j_array: jpy.JType, conv_null: bool) -> pd.Series:
+    """Produce a copy of the specified Java array as a pandas.Series object.
+
+    Args:
+        dtype (DType): the data type of the Java array
+        j_array (jpy.JType): the Java array
+        conv_null (bool): whether to check for Deephaven nulls in the data and automatically replace them with
+            pd.NA
+
+    Returns:
+        a pandas Series
+
+    Raises:
+        DHError
+    """
+    if conv_null and dtype == dtypes.bool_:
+        j_array = _JPrimitiveArrayConversionUtility.translateArrayBooleanToByte(j_array)
+        np_array = np.frombuffer(j_array, dtype=np.byte)
+        s = pd.Series(data=np_array, dtype=pd.Int8Dtype(), copy=False)
+        s.mask(s == _NULL_BOOLEAN_AS_BYTE, inplace=True)
+        return s.astype(pd.BooleanDtype(), copy=False)
+
+    np_array = _j_array_to_numpy_array(dtype, j_array, conv_null=False)
+    if conv_null and (nv := _PRIMITIVE_DTYPE_NULL_MAP.get(dtype)) is not None:
+        pd_ex_dtype = _DH_PANDAS_NULLABLE_TYPE_MAP.get(dtype)
+        s = pd.Series(data=np_array, dtype=pd_ex_dtype(), copy=False)
+        s.mask(s == nv, inplace=True)
+    else:
+        s = pd.Series(data=np_array, copy=False)
+
+    return s
+
+
+def _convert_udf_args(args: Tuple[Any, ...], fn_signature: str, null_value: Literal[np.nan, pd.NA, None]) -> List[Any]:
+    converted_args = []
+    for arg, np_dtype_char in zip(args, fn_signature):
+        if np_dtype_char == 'O':
+            converted_args.append(arg)
+        elif src_np_dtype := _J_ARRAY_NP_TYPE_MAP.get(type(arg)):
+            # array types
+            np_dtype = np.dtype(np_dtype_char)
+            if src_np_dtype != np_dtype and np_dtype != np.object_:
+                raise DHError(f"Cannot convert Java array of type {src_np_dtype} to numpy array of type {np_dtype}")
+            dtype = dtypes.from_np_dtype(np_dtype)
+            if null_value is pd.NA:
+                converted_args.append(_j_array_to_series(dtype, arg, conv_null=True))
+            else:  # np.nan or None
+                converted_args.append(_j_array_to_numpy_array(dtype, arg, conv_null=bool(null_value)))
+        else:  # scalar type or array types that don't need conversion
+            try:
+                np_dtype = np.dtype(np_dtype_char)
+            except TypeError:
+                converted_args.append(arg)
+            else:
+                dtype = dtypes.from_np_dtype(np_dtype)
+                if dtype is dtypes.bool_:
+                    converted_args.append(null_value if arg is None else arg)
+                elif dh_null := _PRIMITIVE_DTYPE_NULL_MAP.get(dtype):
+                    converted_args.append(null_value if arg == dh_null else arg)
+                else:
+                    converted_args.append(arg)
+    return converted_args
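The floating-point `conv_null` path is easy to replay with plain numpy. A sketch — the sentinel below is an assumption mirroring `deephaven.constants.NULL_DOUBLE` (`-Double.MAX_VALUE`):

```python
import numpy as np

NULL_DOUBLE = -np.finfo(np.float64).max  # assumed equal to deephaven.constants.NULL_DOUBLE

buf = np.array([1.5, NULL_DOUBLE, 2.5], np.float64).tobytes()  # stand-in for a Java double[]
arr = np.frombuffer(buf, np.float64)  # zero-copy view, like the primitive fast path above
arr = np.copy(arr)                    # frombuffer views are read-only, so copy before writing
arr[arr == NULL_DOUBLE] = np.nan      # what conv_null does for float32/float64 columns
print(arr)  # [1.5 nan 2.5]
```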
""" import re +from functools import wraps from typing import List import jpy import numpy as np -from deephaven.dtypes import DType -from deephaven import DHError, dtypes, empty_table, new_table +from deephaven import DHError, dtypes, new_table from deephaven.column import Column, InputColumn -from deephaven.table import Table +from deephaven.dtypes import DType +from deephaven.jcompat import _j_array_to_numpy_array, _convert_udf_args +from deephaven.table import Table, _encode_signature -_JPrimitiveArrayConversionUtility = jpy.get_type("io.deephaven.integrations.common.PrimitiveArrayConversionUtility") _JDataAccessHelpers = jpy.get_type("io.deephaven.engine.table.impl.DataAccessHelpers") @@ -25,28 +26,9 @@ def _to_column_name(name: str) -> str: def column_to_numpy_array(col_def: Column, j_array: jpy.JType) -> np.ndarray: - """ Produces a numpy array from the given Java array and the Table column definition. """ + """ Produces a numpy array from the given Java array and the Table column definition.""" try: - if col_def.data_type.is_primitive: - np_array = np.frombuffer(j_array, col_def.data_type.np_type) - elif col_def.data_type == dtypes.Instant: - longs = _JPrimitiveArrayConversionUtility.translateArrayInstantToLong(j_array) - np_long_array = np.frombuffer(longs, np.int64) - np_array = np_long_array.view(col_def.data_type.np_type) - elif col_def.data_type == dtypes.bool_: - bytes_ = _JPrimitiveArrayConversionUtility.translateArrayBooleanToByte(j_array) - np_array = np.frombuffer(bytes_, col_def.data_type.np_type) - elif col_def.data_type == dtypes.string: - np_array = np.array([s for s in j_array], dtypes.string.np_type) - elif col_def.data_type.np_type is not np.object_: - try: - np_array = np.frombuffer(j_array, col_def.data_type.np_type) - except: - np_array = np.array(j_array, np.object_) - else: - np_array = np.array(j_array, np.object_) - - return np_array + return _j_array_to_numpy_array(col_def.data_type, j_array) except DHError: raise except Exception as e: diff --git a/py/server/deephaven/pandas.py b/py/server/deephaven/pandas.py index 883622ce27b..8626b999e11 100644 --- a/py/server/deephaven/pandas.py +++ b/py/server/deephaven/pandas.py @@ -3,7 +3,7 @@ # """ This module supports the conversion between Deephaven tables and pandas DataFrames. 
""" -from typing import List, Dict, Tuple, Literal +from typing import List, Literal import jpy import numpy as np @@ -13,26 +13,14 @@ from deephaven import DHError, new_table, dtypes, arrow from deephaven.column import Column from deephaven.constants import NULL_BYTE, NULL_SHORT, NULL_INT, NULL_LONG, NULL_FLOAT, NULL_DOUBLE, NULL_CHAR -from deephaven.dtypes import DType -from deephaven.numpy import column_to_numpy_array, _make_input_column +from deephaven.jcompat import _j_array_to_series +from deephaven.numpy import _make_input_column from deephaven.table import Table _NULL_BOOLEAN_AS_BYTE = jpy.get_type("io.deephaven.util.BooleanUtils").NULL_BOOLEAN_AS_BYTE -_JPrimitiveArrayConversionUtility = jpy.get_type("io.deephaven.integrations.common.PrimitiveArrayConversionUtility") _JDataAccessHelpers = jpy.get_type("io.deephaven.engine.table.impl.DataAccessHelpers") _is_dtype_backend_supported = pd.__version__ >= "2.0.0" -_DTYPE_NULL_MAPPING: Dict[DType, Tuple] = { - dtypes.bool_: (_NULL_BOOLEAN_AS_BYTE, pd.BooleanDtype), - dtypes.byte: (NULL_BYTE, pd.Int8Dtype), - dtypes.short: (NULL_SHORT, pd.Int16Dtype), - dtypes.char: (NULL_CHAR, pd.UInt16Dtype), - dtypes.int32: (NULL_INT, pd.Int32Dtype), - dtypes.int64: (NULL_LONG, pd.Int64Dtype), - dtypes.float32: (NULL_FLOAT, pd.Float32Dtype), - dtypes.float64: (NULL_DOUBLE, pd.Float64Dtype), -} - def _column_to_series(table: Table, col_def: Column, conv_null: bool) -> pd.Series: """Produce a copy of the specified column as a pandas.Series object. @@ -51,29 +39,15 @@ def _column_to_series(table: Table, col_def: Column, conv_null: bool) -> pd.Seri """ try: data_col = _JDataAccessHelpers.getColumn(table.j_table, col_def.name) - if conv_null and col_def.data_type == dtypes.bool_: - j_array = _JPrimitiveArrayConversionUtility.translateArrayBooleanToByte(data_col.getDirect()) - np_array = np.frombuffer(j_array, dtype=np.byte) - s = pd.Series(data=np_array, dtype=pd.Int8Dtype(), copy=False) - s.mask(s == _NULL_BOOLEAN_AS_BYTE, inplace=True) - return s.astype(pd.BooleanDtype(), copy=False) - - np_array = column_to_numpy_array(col_def, data_col.getDirect()) - if conv_null and (null_pair := _DTYPE_NULL_MAPPING.get(col_def.data_type)) is not None: - nv = null_pair[0] - pd_ex_dtype = null_pair[1] - s = pd.Series(data=np_array, dtype=pd_ex_dtype(), copy=False) - s.mask(s == nv, inplace=True) - else: - s = pd.Series(data=np_array, copy=False) - return s + j_array = data_col.getDirect() + return _j_array_to_series(col_def.data_type, j_array, conv_null) except DHError: raise except Exception as e: raise DHError(e, message="failed to create a pandas Series for {col}") from e -_DTYPE_MAPPING_PYARROW = { +_PANDAS_ARROW_TYPE_MAP = { pa.int8(): pd.ArrowDtype(pa.int8()), pa.int16(): pd.ArrowDtype(pa.int16()), pa.int32(): pd.ArrowDtype(pa.int32()), @@ -90,7 +64,7 @@ def _column_to_series(table: Table, col_def: Column, conv_null: bool) -> pd.Seri pa.timestamp('ns', tz='UTC'): pd.ArrowDtype(pa.timestamp('ns', tz='UTC')), } -_DTYPE_MAPPING_NUMPY_NULLABLE = { +_PANDAS_NULLABLE_TYPE_MAP = { pa.int8(): pd.Int8Dtype(), pa.int16(): pd.Int16Dtype(), pa.uint16(): pd.UInt16Dtype(), @@ -107,8 +81,8 @@ def _column_to_series(table: Table, col_def: Column, conv_null: bool) -> pd.Seri } _PYARROW_TO_PANDAS_TYPE_MAPPERS = { - "pyarrow": _DTYPE_MAPPING_PYARROW.get, - "numpy_nullable": _DTYPE_MAPPING_NUMPY_NULLABLE.get, + "pyarrow": _PANDAS_ARROW_TYPE_MAP.get, + "numpy_nullable": _PANDAS_NULLABLE_TYPE_MAP.get, } @@ -180,7 +154,7 @@ def to_pandas(table: Table, cols: List[str] = None, raise 
DHError(e, "failed to create a pandas DataFrame from table.") from e -_EX_DTYPE_NULL_MAP = { +_PANDAS_EXTYPE_DH_NULL_MAP = { # This reflects the fact that in the server we use NULL_BOOLEAN_AS_BYTE - the byte encoding of null boolean to # translate boxed Boolean to/from primitive bytes pd.BooleanDtype: _NULL_BOOLEAN_AS_BYTE, @@ -209,7 +183,7 @@ def _map_na(array: [np.ndarray, pd.api.extensions.ExtensionArray]): if not isinstance(pd_dtype, pd.api.extensions.ExtensionDtype): return array - dh_null = _EX_DTYPE_NULL_MAP.get(type(pd_dtype)) or _EX_DTYPE_NULL_MAP.get(pd_dtype) + dh_null = _PANDAS_EXTYPE_DH_NULL_MAP.get(type(pd_dtype)) or _PANDAS_EXTYPE_DH_NULL_MAP.get(pd_dtype) # To preserve NaNs in floating point arrays, Pandas doesn't distinguish NaN/Null as far as NA testing is # concerned, thus its fillna() method will replace both NaN/Null in the data. if isinstance(pd_dtype, (pd.Float32Dtype, pd.Float64Dtype)) and isinstance(getattr(array, "_data"), np.ndarray): @@ -276,3 +250,4 @@ def to_table(df: pd.DataFrame, cols: List[str] = None) -> Table: raise except Exception as e: raise DHError(e, "failed to create a Deephaven Table from a pandas DataFrame.") from e + diff --git a/py/server/deephaven/table.py b/py/server/deephaven/table.py index e46348667b9..f0d9fa8a2da 100644 --- a/py/server/deephaven/table.py +++ b/py/server/deephaven/table.py @@ -9,14 +9,16 @@ import contextlib import inspect +from dataclasses import dataclass, field from enum import Enum from enum import auto from functools import wraps -from typing import Any, Optional, Callable, Dict, _GenericAlias +from typing import Any, Optional, Callable, Dict, _GenericAlias, Set, Tuple from typing import Sequence, List, Union, Protocol import jpy import numba +import numpy import numpy as np from deephaven import DHError @@ -27,12 +29,14 @@ from deephaven.agg import Aggregation from deephaven.column import Column, ColumnType from deephaven.filters import Filter, and_, or_ -from deephaven.jcompat import j_unary_operator, j_binary_operator, j_map_to_dict, j_hashmap +from deephaven.jcompat import j_unary_operator, j_binary_operator, j_map_to_dict, j_hashmap, _convert_udf_args, \ + _j_array_to_numpy_array from deephaven.jcompat import to_sequence, j_array_list +from deephaven.time import to_np_datetime64 from deephaven.update_graph import auto_locking_ctx, UpdateGraph from deephaven.updateby import UpdateByOperation -from deephaven.dtypes import _BUILDABLE_ARRAY_DTYPE_MAP, _scalar, _np_dtype_char, \ - _component_np_dtype_char +from deephaven.dtypes import _BUILDABLE_ARRAY_DTYPE_MAP, _scalar, _np_dtype_char, _component_np_dtype_char, DType, \ + _np_ndarray_component_type, _J_ARRAY_NP_TYPE_MAP, _PRIMITIVE_DTYPE_NULL_MAP # Table _J_Table = jpy.get_type("io.deephaven.engine.table.Table") @@ -363,21 +367,7 @@ def _j_py_script_session() -> _JPythonScriptSession: return None -_SUPPORTED_NP_TYPE_CODES = ["i", "l", "h", "f", "d", "b", "?", "U", "M", "O"] - - -def _parse_annotation(annotation: Any) -> Any: - """Parse a Python annotation, for now mostly to extract the non-None type from an Optional(Union) annotation, - otherwise return the original annotation. 
""" - if isinstance(annotation, _GenericAlias) and annotation.__origin__ == Union and len(annotation.__args__) == 2: - if annotation.__args__[1] == type(None): # noqa: E721 - return annotation.__args__[0] - elif annotation.__args__[0] == type(None): # noqa: E721 - return annotation.__args__[1] - else: - return annotation - else: - return annotation +_SUPPORTED_NP_TYPE_CODES = {"b", "h", "i", "l", "f", "d", "?", "U", "M", "O"} def _encode_signature(fn: Callable) -> str: @@ -394,15 +384,15 @@ def _encode_signature(fn: Callable) -> str: # numpy ufuncs actually have signature encoded in their 'types' attribute, we want to better support # them in the future (https://github.com/deephaven/deephaven-core/issues/4762) if type(fn) == np.ufunc: - return "O"*fn.nin + "->" + "O" + return "O" * fn.nin + "->" + "O" return "->O" np_type_codes = [] for n, p in sig.parameters.items(): - p_annotation = _parse_annotation(p.annotation) + p_annotation = _parse_param_annotation(p.annotation) np_type_codes.append(_np_dtype_char(p_annotation)) - return_annotation = _parse_annotation(sig.return_annotation) + return_annotation = _parse_param_annotation(sig.return_annotation) return_type_code = _np_dtype_char(return_annotation) np_type_codes = [c if c in _SUPPORTED_NP_TYPE_CODES else "O" for c in np_type_codes] return_type_code = return_type_code if return_type_code in _SUPPORTED_NP_TYPE_CODES else "O" @@ -411,11 +401,188 @@ def _encode_signature(fn: Callable) -> str: return "".join(np_type_codes) -def _udf_return_dtype(fn): +@dataclass +class ParsedAnnotation: + orig_types: set[type] = field(default_factory=set) + encoded_types: set[str] = field(default_factory=set) + is_optional: bool = False + is_array: bool = False + + +@dataclass +class ParsedSignature: + fn: Callable = None + parameters: List[ParsedAnnotation] = field(default_factory=list) + return_annotation: ParsedAnnotation = None + + +def _encode_param_type(t: type) -> str: + """Returns the numpy based char codes for the given type. + If the type is a numpy ndarray, prefix the numpy dtype char with '[' using Java convention + If the type is a NoneType, return 'N' + """ + if t is type(None): + return "N" + + # find the component type if it is numpy ndarray + component_type = _np_ndarray_component_type(t) + if component_type: + t = component_type + + tc = _np_dtype_char(t) + tc = tc if tc in _SUPPORTED_NP_TYPE_CODES else "O" + + if component_type: + tc = "[" + tc + return tc + + +def _parse_param_annotation(annotation: Any) -> ParsedAnnotation: + """ Parse an annotation in a function's signature """ + pa = ParsedAnnotation() + + # in the absence of annotations, we'll use the 'n' to indicate that. 
+
+
+def _parse_param_annotation(annotation: Any) -> ParsedAnnotation:
+    """ Parse a parameter annotation in a function's signature """
+    pa = ParsedAnnotation()
+
+    if annotation is inspect._empty:
+        # in the absence of an annotation, we use 'n' to indicate that
+        pa.encoded_types.add("n")
+    elif isinstance(annotation, _GenericAlias) and annotation.__origin__ == Union:
+        for t in annotation.__args__:
+            pa.orig_types.add(t)
+            tc = _encode_param_type(t)
+            if "[" in tc:
+                pa.is_array = True
+            elif tc == "N":
+                pa.is_optional = True
+            pa.encoded_types.add(tc)
+    else:
+        pa.orig_types.add(annotation)
+        pa.encoded_types.add(_encode_param_type(annotation))
+    return pa
+
+
+def _parse_return_annotation(annotation: Any) -> ParsedAnnotation:
+    """ Parse the return annotation in a function's signature """
+    pa = ParsedAnnotation()
+
+    t = annotation
+    pa.orig_types.add(t)
+    if isinstance(annotation, _GenericAlias) and annotation.__origin__ == Union and len(annotation.__args__) == 2:
+        if annotation.__args__[1] == type(None):  # noqa: E721
+            t = annotation.__args__[0]
+        elif annotation.__args__[0] == type(None):  # noqa: E721
+            t = annotation.__args__[1]
+    component_char = _component_np_dtype_char(t)
+    if component_char:
+        pa.encoded_types.add("[" + component_char)
+        pa.is_array = True
+    else:
+        pa.encoded_types.add(_np_dtype_char(t))
+    return pa
+
+
+def _parse_signature(fn: Callable) -> ParsedSignature:
+    """ Parse the signature of a function and return a ParsedSignature object """
+    parsed_signature = ParsedSignature(fn=fn)
+    parsed_annotations = []
+
+    if isinstance(fn, (numba.np.ufunc.gufunc.GUFunc, numba.np.ufunc.dufunc.DUFunc)):
+        sig = fn.signature
+        rtype = sig.split("->")[-1].strip("()")
+        for p in sig.split("(")[1].split(")")[0].split(","):
+            parsed_annotations.append(_parse_param_annotation(p))
+        # the layout signature (e.g. "(m),(n)->(n)") only says whether the return is an array, not its element
+        # type, so default the element type to the generic object type
+        ret_annotation = ParsedAnnotation(encoded_types={"[O"} if rtype else {"O"}, is_array=bool(rtype))
+    elif isinstance(fn, np.ufunc):
+        # numpy ufuncs don't support inspect.signature(), so default everything to the generic object type.
+        # they actually have their signature encoded in their 'types' attribute; we want to better support
+        # them in the future (https://github.com/deephaven/deephaven-core/issues/4762)
+        parsed_annotations = [ParsedAnnotation(encoded_types={"O"}) for _ in range(fn.nin)]
+        ret_annotation = ParsedAnnotation(encoded_types={"O"})
+    else:
+        sig = inspect.signature(fn)
+        for n, p in sig.parameters.items():
+            parsed_annotations.append(_parse_param_annotation(p.annotation))
+        ret_annotation = _parse_return_annotation(sig.return_annotation)
+
+    parsed_signature.parameters = parsed_annotations
+    parsed_signature.return_annotation = ret_annotation
+
+    if len(ret_annotation.orig_types) > 1:
+        raise ValueError("only a single return type is supported")
+
+    return parsed_signature
+
+
+def _udf_return_dtype(fn: Callable, signature: str) -> DType:
+    if isinstance(fn, (numba.np.ufunc.dufunc.DUFunc, numba.np.ufunc.gufunc.GUFunc)) and hasattr(fn, "types"):
+        return dtypes.from_np_dtype(np.dtype(fn.types[0][-1]))
+    else:
+        return dtypes.from_np_dtype(np.dtype(signature[-1]))
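The Union branch above is essentially what `typing.get_args` exposes. A standalone sketch of how `Optional[float]` decomposes into the `is_optional`/`encoded_types` fields:

```python
import typing
from typing import Optional

ann = Optional[float]            # really Union[float, None]
args = typing.get_args(ann)      # (<class 'float'>, <class 'NoneType'>)
is_optional = type(None) in args
encoded = {"d" if t is float else "N" for t in args}
print(is_optional, encoded)      # True {'d', 'N'} (set order may vary)
```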
elif "0" in param.encoded_types or "n" in param.encoded_types: + return arg + else: + raise TypeError(f"Argument {arg} is not compatible with annotation {param.orig_types}") + else: # if the arg is not Java array + # find the numpy dtype for the annotation + # if found, convert the arg to that type, take care of nulls (if null and optional, return None, else raise) + # if not found, if empyt annotation, return arg, else return + # if dh_null := _PRIMITIVE_DTYPE_NULL_MAP.get(dtype): + possible_types = param.encoded_types - {"n", "O"} + if possible_types: + for t in possible_types: + if t.startswith("["): + continue + + dtype = dtypes.from_np_dtype(np.dtype(t)) + dh_null = _PRIMITIVE_DTYPE_NULL_MAP.get(dtype) + + if t in {"b", "h", "i", "l"}: + if isinstance(arg, int): + if arg == dh_null: + if "N" in param.encoded_types: + return None + else: + raise TypeError(f"Argument {arg} is not compatible with annotation {param.orig_types}") + else: + return arg + elif t in {"f", "d"}: + if isinstance(arg, float): + if arg == dh_null: + return np.nan if "N" not in param.encoded_types else None + else: + return arg + elif t == "?": + if isinstance(arg, bool): + return arg + elif t == "M": + return to_np_datetime64(arg) + elif t == "U": + return str(arg) + else: + raise TypeError(f"Argument {arg} is not compatible with annotation {param.orig_types}") + else: + return arg + + +def _convert_args(parsed_sig: ParsedSignature, args: Tuple[Any, ...]) -> List[Any, ...]: + converted_args = [] + for arg, param in zip(args, parsed_sig.parameters): + converted_args.append(_convert_arg(param, arg)) + return converted_args def _py_udf(fn: Callable): @@ -423,48 +590,25 @@ def _py_udf(fn: Callable): Python and Java. This decorator is intended for use by the Deephaven query engine and should not be used by users. - For now, this decorator is only capable of converting Python function return values to Java values. It - does not yet convert Java values in arguments to usable Python object (e.g. numpy arrays) or properly translate - Deephaven primitive null values. - - For properly annotated functions, including numba vectorized and guvectorized ones, this decorator inspects the - signature of the function and determines its return type, including supported primitive types and arrays of - the supported primitive types. It then converts the return value of the function to the corresponding Java value - of the same type. For unsupported types, the decorator returns the original Python value which appears as - org.jpy.PyObject in Java. + It carries out two conversions: + 1. convert Python function return values to Java values. + For properly annotated functions, including numba vectorized and guvectorized ones, this decorator inspects the + signature of the function and determines its return type, including supported primitive types and arrays of + the supported primitive types. It then converts the return value of the function to the corresponding Java value + of the same type. For unsupported types, the decorator returns the original Python value which appears as + org.jpy.PyObject in Java. + 2. convert Java function arguments to Python values based on the signature of the function. """ - if hasattr(fn, "return_type"): return fn - ret_dtype = _udf_return_dtype(fn) - - return_array = False - # If the function is a numba guvectorized function, examine the signature of the function to determine if it - # returns an array. 
 
 
 def _py_udf(fn: Callable):
@@ -423,48 +590,25 @@ def _py_udf(fn: Callable):
     Python and Java. This decorator is intended for use by the Deephaven query engine and should not be used by
     users.
 
-    For now, this decorator is only capable of converting Python function return values to Java values. It
-    does not yet convert Java values in arguments to usable Python object (e.g. numpy arrays) or properly translate
-    Deephaven primitive null values.
-
-    For properly annotated functions, including numba vectorized and guvectorized ones, this decorator inspects the
-    signature of the function and determines its return type, including supported primitive types and arrays of
-    the supported primitive types. It then converts the return value of the function to the corresponding Java value
-    of the same type. For unsupported types, the decorator returns the original Python value which appears as
-    org.jpy.PyObject in Java.
+    It carries out two conversions:
+    1. convert Python function return values to Java values.
+        For properly annotated functions, including numba vectorized and guvectorized ones, this decorator inspects
+        the signature of the function and determines its return type, including supported primitive types and arrays
+        of the supported primitive types. It then converts the return value of the function to the corresponding Java
+        value of the same type. For unsupported types, the decorator returns the original Python value which appears
+        as org.jpy.PyObject in Java.
+    2. convert Java function arguments to Python values based on the signature of the function.
     """
-
     if hasattr(fn, "return_type"):
         return fn
-    ret_dtype = _udf_return_dtype(fn)
-
-    return_array = False
-    # If the function is a numba guvectorized function, examine the signature of the function to determine if it
-    # returns an array.
-    if isinstance(fn, numba.np.ufunc.gufunc.GUFunc):
-        sig = fn.signature
-        rtype = sig.split("->")[-1].strip("()")
-        if rtype:
-            return_array = True
-    else:
-        try:
-            return_annotation = _parse_annotation(inspect.signature(fn).return_annotation)
-        except ValueError:
-            # the function has no return annotation, and since we can't know what the exact type is, the return type
-            # defaults to the generic object type therefore it is not an array of a specific type,
-            # but see (https://github.com/deephaven/deephaven-core/issues/4762) for future imporvement to better support
-            # numpy ufuncs.
-            pass
-        else:
-            component_type = _component_np_dtype_char(return_annotation)
-            if component_type:
-                ret_dtype = dtypes.from_np_dtype(np.dtype(component_type))
-                if ret_dtype in _BUILDABLE_ARRAY_DTYPE_MAP:
-                    return_array = True
+    parsed_signature = _parse_signature(fn)
+    return_array = parsed_signature.return_annotation.is_array
+    ret_dtype = dtypes.from_np_dtype(np.dtype(list(parsed_signature.return_annotation.encoded_types)[0][-1]))
 
     @wraps(fn)
     def wrapper(*args, **kwargs):
-        ret = fn(*args, **kwargs)
+        converted_args = _convert_args(parsed_signature, args)
+        ret = fn(*converted_args, **kwargs)
         if return_array:
             return dtypes.array(ret_dtype, ret)
         elif ret_dtype == dtypes.PyObject:
@@ -473,7 +617,7 @@ def wrapper(*args, **kwargs):
             return _scalar(ret, ret_dtype)
 
     wrapper.j_name = ret_dtype.j_name
-    real_ret_dtype = _BUILDABLE_ARRAY_DTYPE_MAP.get(ret_dtype) if return_array else ret_dtype
+    real_ret_dtype = _BUILDABLE_ARRAY_DTYPE_MAP.get(ret_dtype, dtypes.PyObject) if return_array else ret_dtype
 
     if hasattr(ret_dtype.j_type, 'jclass'):
         j_class = real_ret_dtype.j_type.jclass
@@ -500,14 +644,14 @@ def dh_vectorize(fn):
     The current vectorized function signature includes (1) the size of the input arrays, (2) the output array, and
     (3) the input arrays.
     """
-    signature = _encode_signature(fn)
-    ret_dtype = _udf_return_dtype(fn)
+    fn_signature = _encode_signature(fn)
+    ret_dtype = _udf_return_dtype(fn, signature=fn_signature)
 
     @wraps(fn)
     def wrapper(*args):
-        if len(args) != len(signature) - len("->?") + 2:
+        if len(args) != len(fn_signature) - len("->?") + 2:
             raise ValueError(
-                f"The number of arguments doesn't match the function signature. {len(args) - 2}, {signature}")
+                f"The number of arguments doesn't match the function signature. {len(args) - 2}, {fn_signature}")
         if args[0] <= 0:
             raise ValueError(f"The chunk size argument must be a positive integer. {args[0]}")
 
@@ -525,7 +669,7 @@ def wrapper(*args):
         return chunk_result
 
     wrapper.callable = fn
-    wrapper.signature = signature
+    wrapper.signature = fn_signature
     wrapper.dh_vectorized = True
 
     if _test_vectorization:
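Put together, the decorator lets an annotated UDF receive Python-friendly values from the engine. An illustrative query (requires a running Deephaven server session; the engine applies the decorator automatically when the function is called from a formula):

```python
from typing import Optional

from deephaven import empty_table

def label(x: Optional[int]) -> str:
    # with the argument conversion above, a NULL_INT cell reaches the function as None
    return "missing" if x is None else f"value-{x}"

t = empty_table(4).update(["X = i % 2 == 0 ? NULL_INT : i", "Y = label(X)"])
```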
""" @@ -3761,7 +3906,8 @@ def __init__(self, input: Union[Table, Sequence[Table], MultiJoinInput, Sequence with auto_locking_ctx(*tables): j_tables = to_sequence(input) self.j_multijointable = _JMultiJoinFactory.of(on, *j_tables) - elif isinstance(input, MultiJoinInput) or (isinstance(input, Sequence) and all(isinstance(ji, MultiJoinInput) for ji in input)): + elif isinstance(input, MultiJoinInput) or ( + isinstance(input, Sequence) and all(isinstance(ji, MultiJoinInput) for ji in input)): if on is not None: raise DHError(message="on parameter is not permitted when MultiJoinInput objects are provided.") wrapped_input = to_sequence(input, wrapped=True) @@ -3770,13 +3916,13 @@ def __init__(self, input: Union[Table, Sequence[Table], MultiJoinInput, Sequence input = to_sequence(input) self.j_multijointable = _JMultiJoinFactory.of(*input) else: - raise DHError(message="input must be a Table, a sequence of Tables, a MultiJoinInput, or a sequence of MultiJoinInputs.") + raise DHError( + message="input must be a Table, a sequence of Tables, a MultiJoinInput, or a sequence of MultiJoinInputs.") except Exception as e: raise DHError(e, "failed to build a MultiJoinTable object.") from e - def multi_join(input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]], on: Union[str, Sequence[str]] = None) -> MultiJoinTable: """ The multi_join method creates a new table by performing a multi-table natural join on the input tables. The result @@ -3794,4 +3940,4 @@ def multi_join(input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[Mul MultiJoinTable: the result of the multi-table natural join operation. To access the underlying Table, use the table() method. """ - return MultiJoinTable(input, on) \ No newline at end of file + return MultiJoinTable(input, on)