Skip to content

Commit

Permalink
auto convert Java values(arrays/scalar) to Numpy ones and convert DH …
Browse files Browse the repository at this point in the history
…nulls based on the annotations of the params of a Py UDF (#4502)

* A bit of a milestone

* made test suite pass

* Refactor the new code

* Add more tests

* More refactoring and code cleanup

* Fix a bug that fails vectorization

* More code cleanup and clarification

* More pathological test cases

* Fix String/Instant array conversion issue

* Fix test failures and refactor code

* Trivial renaming

* Respond to review comments

* Apply suggestions from code review

Co-authored-by: Chip Kent <[email protected]>

* Refactor the code and a minor fixes

* Improve the test cases

* Clearly distinguqish between params and return

* Clarify some code with comments

* More clarifying comments

---------

Co-authored-by: Chip Kent <[email protected]>
  • Loading branch information
jmao-denver and chipkent committed Dec 1, 2023
1 parent 6d1d1c9 commit 19b1d7e
Show file tree
Hide file tree
Showing 11 changed files with 1,061 additions and 315 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,18 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper {
private static final PyObject NUMBA_VECTORIZED_FUNC_TYPE = getNumbaVectorizedFuncType();
private static final PyObject NUMBA_GUVECTORIZED_FUNC_TYPE = getNumbaGUVectorizedFuncType();

private static final PyModule dh_table_module = PyModule.importModule("deephaven.table");
private static final PyModule dh_udf_module = PyModule.importModule("deephaven._udf");

private static final Map<Character, Class<?>> numpyType2JavaClass = new HashMap<>();

static {
numpyType2JavaClass.put('b', byte.class);
numpyType2JavaClass.put('h', short.class);
numpyType2JavaClass.put('H', char.class);
numpyType2JavaClass.put('i', int.class);
numpyType2JavaClass.put('l', long.class);
numpyType2JavaClass.put('h', short.class);
numpyType2JavaClass.put('f', float.class);
numpyType2JavaClass.put('d', double.class);
numpyType2JavaClass.put('b', byte.class);
numpyType2JavaClass.put('?', boolean.class);
numpyType2JavaClass.put('U', String.class);
numpyType2JavaClass.put('M', Instant.class);
Expand Down Expand Up @@ -133,23 +134,21 @@ private void prepareSignature() {
pyCallable
+ " has multiple signatures; this is not currently supported for numba vectorized/guvectorized functions");
}
signature = params.get(0).getStringValue();
unwrapped = pyCallable;
// since vectorization doesn't support array type parameters, don't flag numba guvectorized as vectorized
numbaVectorized = isNumbaVectorized;
vectorized = isNumbaVectorized;
} else if (pyCallable.hasAttribute("dh_vectorized")) {
signature = pyCallable.getAttribute("signature").toString();
unwrapped = pyCallable.getAttribute("callable");
numbaVectorized = false;
vectorized = true;
} else {
signature = dh_table_module.call("_encode_signature", pyCallable).toString();
unwrapped = pyCallable;
numbaVectorized = false;
vectorized = false;
}
pyUdfDecoratedCallable = dh_table_module.call("_py_udf", unwrapped);
pyUdfDecoratedCallable = dh_udf_module.call("_py_udf", unwrapped);
signature = pyUdfDecoratedCallable.getAttribute("signature").toString();
}

@Override
Expand Down Expand Up @@ -199,7 +198,7 @@ public PyObject vectorizedCallable() {
if (numbaVectorized || vectorized) {
return pyCallable;
} else {
return dh_table_module.call("dh_vectorize", unwrapped);
return dh_udf_module.call("_dh_vectorize", unwrapped);
}
}

Expand Down
Loading

0 comments on commit 19b1d7e

Please sign in to comment.