Skip to content

Commit

Permalink
improved ndarray conversion for JAX (fixes issue #729)
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Sep 20, 2024
1 parent c1be430 commit 4647efc
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 75 deletions.
154 changes: 81 additions & 73 deletions src/nb_ndarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ struct ndarray_handle {
bool ro;
};

static void ndarray_capsule_destructor(PyObject *o) {
error_scope scope; // temporarily save any existing errors
managed_dltensor *mt =
(managed_dltensor *) PyCapsule_GetPointer(o, "dltensor");

if (mt)
ndarray_dec_ref((ndarray_handle *) mt->manager_ctx);
else
PyErr_Clear();
}

static void nb_ndarray_dealloc(PyObject *self) {
PyTypeObject *tp = Py_TYPE(self);
ndarray_dec_ref(((nb_ndarray *) self)->th);
Expand Down Expand Up @@ -123,12 +134,52 @@ static void nb_ndarray_releasebuffer(PyObject *, Py_buffer *view) {
PyMem_Free(view->strides);
}


static PyObject *nb_ndarray_dlpack(PyObject *self, PyTypeObject *,
PyObject *const *, Py_ssize_t ,
PyObject *) {
nb_ndarray *self_nd = (nb_ndarray *) self;
ndarray_handle *th = self_nd->th;

PyObject *r =
PyCapsule_New(th->ndarray, "dltensor", ndarray_capsule_destructor);
if (r)
ndarray_inc_ref(th);
return r;
}

static PyObject *nb_ndarray_dlpack_device(PyObject *self, PyTypeObject *,
PyObject *const *, Py_ssize_t ,
PyObject *) {
nb_ndarray *self_nd = (nb_ndarray *) self;
dlpack::dltensor &t = self_nd->th->ndarray->dltensor;
PyObject *r = PyTuple_New(2);
PyObject *r0 = PyLong_FromLong(t.device.device_type);
PyObject *r1 = PyLong_FromLong(t.device.device_id);
if (!r || !r0 || !r1) {
Py_XDECREF(r);
Py_XDECREF(r0);
Py_XDECREF(r1);
return nullptr;
}
NB_TUPLE_SET_ITEM(r, 0, r0);
NB_TUPLE_SET_ITEM(r, 1, r1);
return r;
}

static PyMethodDef nb_ndarray_members[] = {
{ "__dlpack__", (PyCFunction) nb_ndarray_dlpack, METH_FASTCALL | METH_KEYWORDS, nullptr },
{ "__dlpack_device__", (PyCFunction) nb_ndarray_dlpack_device, METH_FASTCALL | METH_KEYWORDS, nullptr },

This comment has been minimized.

Copy link
@wojdyr

wojdyr Sep 21, 2024

Contributor

This triggers a warning about function casting (GCC 14, -Wextra)

/home/wojdyr/.local/lib/python3.12/site-packages/nanobind/src/nb_ndarray.cpp:171:20: warning: cast between incompatible function types from ‘PyObject* (*)(PyObject*, PyTypeObject*, PyObject* const*, Py_ssize_t, PyObject*)’ {aka ‘_object* (*)(_object*, _typeobject*, _object* const*, long int, _object*)’} to ‘PyCFunction’ {aka ‘_object* (*)(_object*, _object*)’} [-Wcast-function-type]
  171 |    { "__dlpack__", (PyCFunction) nb_ndarray_dlpack, METH_FASTCALL | METH_KEYWORDS, nullptr },
      |                    ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/home/wojdyr/.local/lib/python3.12/site-packages/nanobind/src/nb_ndarray.cpp:172:27: warning: cast between incompatible function types from ‘PyObject* (*)(PyObject*, PyTypeObject*, PyObject* const*, Py_ssize_t, PyObject*)’ {aka ‘_object* (*)(_object*, _typeobject*, _object* const*, long int, _object*)’} to ‘PyCFunction’ {aka ‘_object* (*)(_object*, _object*)’} [-Wcast-function-type]
  172 |    { "__dlpack_device__", (PyCFunction) nb_ndarray_dlpack_device, METH_FASTCALL | METH_KEYWORDS, nullptr },
      |                           ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This comment has been minimized.

Copy link
@wjakob

wjakob Sep 21, 2024

Author Owner

Does commit 99bad51 fix this issue?

This comment has been minimized.

Copy link
@wojdyr

wojdyr Sep 21, 2024

Contributor

yes, thanks!

{ nullptr, nullptr, 0, nullptr }
};

static PyTypeObject *nd_ndarray_tp() noexcept {
PyTypeObject *tp = internals->nb_ndarray;

if (NB_UNLIKELY(!tp)) {
PyType_Slot slots[] = {
{ Py_tp_dealloc, (void *) nb_ndarray_dealloc },
{ Py_tp_methods, (void *) nb_ndarray_members },
#if PY_VERSION_HEX >= 0x03090000
{ Py_bf_getbuffer, (void *) nd_ndarray_tpbuffer },
{ Py_bf_releasebuffer, (void *) nb_ndarray_releasebuffer },
Expand Down Expand Up @@ -649,17 +700,6 @@ ndarray_handle *ndarray_create(void *value, size_t ndim, const size_t *shape_in,
return result.release();
}

static void ndarray_capsule_destructor(PyObject *o) {
error_scope scope; // temporarily save any existing errors
managed_dltensor *mt =
(managed_dltensor *) PyCapsule_GetPointer(o, "dltensor");

if (mt)
ndarray_dec_ref((ndarray_handle *) mt->manager_ctx);
else
PyErr_Clear();
}

PyObject *ndarray_export(ndarray_handle *th, int framework,
rv_policy policy, cleanup_list *cleanup) noexcept {
if (!th)
Expand Down Expand Up @@ -706,79 +746,47 @@ PyObject *ndarray_export(ndarray_handle *th, int framework,
}
}

if (framework == numpy::value) {
try {
nb_ndarray *h = PyObject_New(nb_ndarray, nd_ndarray_tp());
if (!h)
return nullptr;
h->th = th;
ndarray_inc_ref(th);

object o = steal((PyObject *) h);
return module_::import_("numpy")
.attr("array")(o, arg("copy") = copy)
.release()
.ptr();
} catch (const std::exception &e) {
PyErr_Format(PyExc_RuntimeError,
"nanobind::detail::ndarray_export(): could not "
"convert ndarray to NumPy array: %s", e.what());
return nullptr;
}
}

object package;
try {
switch (framework) {
case no_framework::value:
break;

case pytorch::value:
package = module_::import_("torch.utils.dlpack");
break;

case tensorflow::value:
package = module_::import_("tensorflow.experimental.dlpack");
break;

case jax::value:
package = module_::import_("jax.dlpack");
break;

case cupy::value:
package = module_::import_("cupy");
break;

default:
check(false, "nanobind::detail::ndarray_export(): unknown "
"framework specified!");
}
} catch (const std::exception &e) {
PyErr_Format(PyExc_RuntimeError,
"nanobind::detail::ndarray_export(): could not import ndarray "
"framework: %s", e.what());
return nullptr;
}

object o;
if (copy && framework == no_framework::value && th->self) {
o = borrow(th->self);
} else if (framework == numpy::value || framework == jax::value) {
nb_ndarray *h = PyObject_New(nb_ndarray, nd_ndarray_tp());
if (!h)
return nullptr;
h->th = th;
ndarray_inc_ref(th);
o = steal((PyObject *) h);
} else {
o = steal(PyCapsule_New(th->ndarray, "dltensor",
ndarray_capsule_destructor));
ndarray_inc_ref(th);
}

try {
if (framework == numpy::value) {
return module_::import_("numpy")
.attr("array")(o, arg("copy") = copy)
.release()
.ptr();
} else {
const char *pkg_name;
switch (framework) {
case pytorch::value: pkg_name = "torch.utils.dlpack"; break;
case tensorflow::value: pkg_name = "tensorflow.experimental.dlpack"; break;
case jax::value: pkg_name = "jax.dlpack"; break;
case cupy::value: pkg_name = "cupy"; break;
default: pkg_name = nullptr;
}

if (package.is_valid()) {
try {
o = package.attr("from_dlpack")(o);
} catch (const std::exception &e) {
PyErr_Format(PyExc_RuntimeError,
"nanobind::detail::ndarray_export(): could not "
"import ndarray: %s", e.what());
return nullptr;
if (pkg_name)
o = module_::import_(pkg_name).attr("from_dlpack")(o);
}
} catch (const std::exception &e) {
PyErr_Format(PyExc_RuntimeError,
"nanobind::detail::ndarray_export(): could not "
"import ndarray: %s",
e.what());
return nullptr;
}

if (copy) {
Expand Down
4 changes: 2 additions & 2 deletions tests/py_stub_test.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class AClass:
@staticmethod
def static_method(x): ...

@staticmethod
def class_method(x): ...
@classmethod
def class_method(cls, x): ...

@overload
def overloaded(self, x: int) -> None:
Expand Down

0 comments on commit 4647efc

Please sign in to comment.