Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support calling numpy ufuncs assigned to top-level names in formulas #4759

Merged
merged 2 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions py/server/deephaven/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,15 @@ def _encode_signature(fn: Callable) -> str:

If a parameter or the return of the function is not annotated, the default 'O' - object type, will be used.
"""
sig = inspect.signature(fn)
try:
sig = inspect.signature(fn)
except:
# in case inspect.signature() fails, we'll just use the default 'O' - object type.
# numpy ufuncs actually have signature encoded in their 'types' attribute, we want to better support
# them in the future (https://github.com/deephaven/deephaven-core/issues/4762)
if type(fn) == np.ufunc:
jmao-denver marked this conversation as resolved.
Show resolved Hide resolved
return "O"*fn.nin + "->" + "O"
return "->O"
jmao-denver marked this conversation as resolved.
Show resolved Hide resolved

np_type_codes = []
for n, p in sig.parameters.items():
Expand Down Expand Up @@ -429,8 +437,8 @@ def _py_udf(fn: Callable):
if hasattr(fn, "return_type"):
return fn
ret_dtype = _udf_return_dtype(fn)
return_array = False

return_array = False
# If the function is a numba guvectorized function, examine the signature of the function to determine if it
# returns an array.
if isinstance(fn, numba.np.ufunc.gufunc.GUFunc):
Expand All @@ -439,12 +447,20 @@ def _py_udf(fn: Callable):
if rtype:
return_array = True
else:
return_annotation = _parse_annotation(inspect.signature(fn).return_annotation)
component_type = _component_np_dtype_char(return_annotation)
if component_type:
ret_dtype = dtypes.from_np_dtype(np.dtype(component_type))
if ret_dtype in _BUILDABLE_ARRAY_DTYPE_MAP:
return_array = True
try:
return_annotation = _parse_annotation(inspect.signature(fn).return_annotation)
except ValueError:
# the function has no return annotation, and since we can't know what the exact type is, the return type
# defaults to the generic object type therefore it is not an array of a specific type,
# but see (https://github.com/deephaven/deephaven-core/issues/4762) for future imporvement to better support
# numpy ufuncs.
pass
else:
component_type = _component_np_dtype_char(return_annotation)
if component_type:
ret_dtype = dtypes.from_np_dtype(np.dtype(component_type))
if ret_dtype in _BUILDABLE_ARRAY_DTYPE_MAP:
return_array = True

@wraps(fn)
def wrapper(*args, **kwargs):
Expand Down
21 changes: 21 additions & 0 deletions py/server/tests/test_pyfunc_return_java_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,27 @@ def f1(col) -> Optional[List[int]]:
self.assertEqual(t.columns[0].data_type, dtypes.long_array)
self.assertEqual(t.to_string().count("null"), 5)

def test_np_ufunc(self):
# no vectorization and no type inference
npsin = np.sin
t = empty_table(10).update(["X1 = npsin(i)"])
self.assertEqual(t.columns[0].data_type, dtypes.PyObject)
t2 = t.update("X2 = X1.getDoubleValue()")
jmao-denver marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(t2.columns[1].data_type, dtypes.double)

import numba

# numba vectorize decorator doesn't support numpy ufunc
with self.assertRaises(TypeError):
nbsin = numba.vectorize([numba.float64(numba.float64)])(np.sin)

# this is the workaround that utilizes vectorization and type inference
@numba.vectorize([numba.float64(numba.float64)], nopython=True)
def nbsin(x):
return np.sin(x)
t3 = empty_table(10).update(["X3 = nbsin(i)"])
self.assertEqual(t3.columns[0].data_type, dtypes.double)


if __name__ == '__main__':
unittest.main()
Loading