From 4647efcc45d96e530d41a3461cd9727656bc2ca3 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Fri, 20 Sep 2024 10:29:20 +0900 Subject: [PATCH] improved ndarray conversion for JAX (fixes issue #729) --- src/nb_ndarray.cpp | 154 ++++++++++++++++++++++------------------- tests/py_stub_test.pyi | 4 +- 2 files changed, 83 insertions(+), 75 deletions(-) diff --git a/src/nb_ndarray.cpp b/src/nb_ndarray.cpp index 463f8741..2816c2cd 100644 --- a/src/nb_ndarray.cpp +++ b/src/nb_ndarray.cpp @@ -23,6 +23,17 @@ struct ndarray_handle { bool ro; }; +static void ndarray_capsule_destructor(PyObject *o) { + error_scope scope; // temporarily save any existing errors + managed_dltensor *mt = + (managed_dltensor *) PyCapsule_GetPointer(o, "dltensor"); + + if (mt) + ndarray_dec_ref((ndarray_handle *) mt->manager_ctx); + else + PyErr_Clear(); +} + static void nb_ndarray_dealloc(PyObject *self) { PyTypeObject *tp = Py_TYPE(self); ndarray_dec_ref(((nb_ndarray *) self)->th); @@ -123,12 +134,52 @@ static void nb_ndarray_releasebuffer(PyObject *, Py_buffer *view) { PyMem_Free(view->strides); } + +static PyObject *nb_ndarray_dlpack(PyObject *self, PyTypeObject *, + PyObject *const *, Py_ssize_t , + PyObject *) { + nb_ndarray *self_nd = (nb_ndarray *) self; + ndarray_handle *th = self_nd->th; + + PyObject *r = + PyCapsule_New(th->ndarray, "dltensor", ndarray_capsule_destructor); + if (r) + ndarray_inc_ref(th); + return r; +} + +static PyObject *nb_ndarray_dlpack_device(PyObject *self, PyTypeObject *, + PyObject *const *, Py_ssize_t , + PyObject *) { + nb_ndarray *self_nd = (nb_ndarray *) self; + dlpack::dltensor &t = self_nd->th->ndarray->dltensor; + PyObject *r = PyTuple_New(2); + PyObject *r0 = PyLong_FromLong(t.device.device_type); + PyObject *r1 = PyLong_FromLong(t.device.device_id); + if (!r || !r0 || !r1) { + Py_XDECREF(r); + Py_XDECREF(r0); + Py_XDECREF(r1); + return nullptr; + } + NB_TUPLE_SET_ITEM(r, 0, r0); + NB_TUPLE_SET_ITEM(r, 1, r1); + return r; +} + +static PyMethodDef nb_ndarray_members[] = { + { "__dlpack__", (PyCFunction) nb_ndarray_dlpack, METH_FASTCALL | METH_KEYWORDS, nullptr }, + { "__dlpack_device__", (PyCFunction) nb_ndarray_dlpack_device, METH_FASTCALL | METH_KEYWORDS, nullptr }, + { nullptr, nullptr, 0, nullptr } +}; + static PyTypeObject *nd_ndarray_tp() noexcept { PyTypeObject *tp = internals->nb_ndarray; if (NB_UNLIKELY(!tp)) { PyType_Slot slots[] = { { Py_tp_dealloc, (void *) nb_ndarray_dealloc }, + { Py_tp_methods, (void *) nb_ndarray_members }, #if PY_VERSION_HEX >= 0x03090000 { Py_bf_getbuffer, (void *) nd_ndarray_tpbuffer }, { Py_bf_releasebuffer, (void *) nb_ndarray_releasebuffer }, @@ -649,17 +700,6 @@ ndarray_handle *ndarray_create(void *value, size_t ndim, const size_t *shape_in, return result.release(); } -static void ndarray_capsule_destructor(PyObject *o) { - error_scope scope; // temporarily save any existing errors - managed_dltensor *mt = - (managed_dltensor *) PyCapsule_GetPointer(o, "dltensor"); - - if (mt) - ndarray_dec_ref((ndarray_handle *) mt->manager_ctx); - else - PyErr_Clear(); -} - PyObject *ndarray_export(ndarray_handle *th, int framework, rv_policy policy, cleanup_list *cleanup) noexcept { if (!th) @@ -706,79 +746,47 @@ PyObject *ndarray_export(ndarray_handle *th, int framework, } } - if (framework == numpy::value) { - try { - nb_ndarray *h = PyObject_New(nb_ndarray, nd_ndarray_tp()); - if (!h) - return nullptr; - h->th = th; - ndarray_inc_ref(th); - - object o = steal((PyObject *) h); - return module_::import_("numpy") - .attr("array")(o, arg("copy") = copy) - .release() - .ptr(); - } catch (const std::exception &e) { - PyErr_Format(PyExc_RuntimeError, - "nanobind::detail::ndarray_export(): could not " - "convert ndarray to NumPy array: %s", e.what()); - return nullptr; - } - } - - object package; - try { - switch (framework) { - case no_framework::value: - break; - - case pytorch::value: - package = module_::import_("torch.utils.dlpack"); - break; - - case tensorflow::value: - package = module_::import_("tensorflow.experimental.dlpack"); - break; - - case jax::value: - package = module_::import_("jax.dlpack"); - break; - - case cupy::value: - package = module_::import_("cupy"); - break; - - default: - check(false, "nanobind::detail::ndarray_export(): unknown " - "framework specified!"); - } - } catch (const std::exception &e) { - PyErr_Format(PyExc_RuntimeError, - "nanobind::detail::ndarray_export(): could not import ndarray " - "framework: %s", e.what()); - return nullptr; - } - object o; if (copy && framework == no_framework::value && th->self) { o = borrow(th->self); + } else if (framework == numpy::value || framework == jax::value) { + nb_ndarray *h = PyObject_New(nb_ndarray, nd_ndarray_tp()); + if (!h) + return nullptr; + h->th = th; + ndarray_inc_ref(th); + o = steal((PyObject *) h); } else { o = steal(PyCapsule_New(th->ndarray, "dltensor", ndarray_capsule_destructor)); ndarray_inc_ref(th); } + try { + if (framework == numpy::value) { + return module_::import_("numpy") + .attr("array")(o, arg("copy") = copy) + .release() + .ptr(); + } else { + const char *pkg_name; + switch (framework) { + case pytorch::value: pkg_name = "torch.utils.dlpack"; break; + case tensorflow::value: pkg_name = "tensorflow.experimental.dlpack"; break; + case jax::value: pkg_name = "jax.dlpack"; break; + case cupy::value: pkg_name = "cupy"; break; + default: pkg_name = nullptr; + } - if (package.is_valid()) { - try { - o = package.attr("from_dlpack")(o); - } catch (const std::exception &e) { - PyErr_Format(PyExc_RuntimeError, - "nanobind::detail::ndarray_export(): could not " - "import ndarray: %s", e.what()); - return nullptr; + if (pkg_name) + o = module_::import_(pkg_name).attr("from_dlpack")(o); } + } catch (const std::exception &e) { + PyErr_Format(PyExc_RuntimeError, + "nanobind::detail::ndarray_export(): could not " + "import ndarray: %s", + e.what()); + return nullptr; } if (copy) { diff --git a/tests/py_stub_test.pyi b/tests/py_stub_test.pyi index f9256f1f..9e6822d0 100644 --- a/tests/py_stub_test.pyi +++ b/tests/py_stub_test.pyi @@ -15,8 +15,8 @@ class AClass: @staticmethod def static_method(x): ... - @staticmethod - def class_method(x): ... + @classmethod + def class_method(cls, x): ... @overload def overloaded(self, x: int) -> None: