diff --git a/docs/api_extra.rst b/docs/api_extra.rst index f310497a8..977dde684 100644 --- a/docs/api_extra.rst +++ b/docs/api_extra.rst @@ -634,13 +634,13 @@ section `. .. cpp:function:: size_t itemsize() const Return the size of a single array element in bytes. The returned value - is rounded to the next full byte in case of bit-level representations + is rounded up to the next full byte in case of bit-level representations (query :cpp:member:`dtype::bits` for bit-level granularity). .. cpp:function:: size_t nbytes() const Return the size of the entire array bytes. The returned value is rounded - to the next full byte in case of bit-level representations. + up to the next full byte in case of bit-level representations. .. cpp:function:: size_t shape(size_t i) const @@ -648,7 +648,7 @@ section `. .. cpp:function:: int64_t stride(size_t i) const - Return the stride of dimension `i`. + Return the stride (in number of elements) of dimension `i`. .. cpp:function:: const int64_t* shape_ptr() const diff --git a/src/nb_ndarray.cpp b/src/nb_ndarray.cpp index 46432833c..aad12aa4e 100644 --- a/src/nb_ndarray.cpp +++ b/src/nb_ndarray.cpp @@ -262,8 +262,14 @@ static PyObject *dlpack_from_buffer_protocol(PyObject *o, bool ro) { scoped_pymalloc strides((size_t) view->ndim); scoped_pymalloc shape((size_t) view->ndim); + const int64_t itemsize = static_cast(view->itemsize); for (size_t i = 0; i < (size_t) view->ndim; ++i) { - strides[i] = (int64_t) (view->strides[i] / view->itemsize); + int64_t stride = view->strides[i] / itemsize; + if (stride * itemsize != view->strides[i]) { + PyBuffer_Release(view.get()); + return nullptr; + } + strides[i] = stride; shape[i] = (int64_t) view->shape[i]; } diff --git a/tests/test_ndarray.cpp b/tests/test_ndarray.cpp index fef267771..87ac291db 100644 --- a/tests/test_ndarray.cpp +++ b/tests/test_ndarray.cpp @@ -44,6 +44,10 @@ NB_MODULE(test_ndarray_ext, m) { return t.nbytes(); }, "array"_a.noconvert()); + m.def("get_stride", [](const nb::ndarray<> &t, size_t i) { + return t.stride(i); + }, "array"_a.noconvert(), "i"_a); + m.def("check_shape_ptr", [](const nb::ndarray<> &t) { std::vector shape(t.ndim()); std::copy(t.shape_ptr(), t.shape_ptr() + t.ndim(), shape.begin()); diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py index 2bc7fb734..2c9e0cbeb 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -558,7 +558,7 @@ def test28_reference_internal(): assert msg in str(excinfo.value) @needs_numpy -def test29_force_contig_pytorch(): +def test29_force_contig_numpy(): a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) b = t.make_contig(a) assert b is a @@ -656,3 +656,28 @@ def __dlpack__(self): arr = DLPackWrapper(np.zeros((1))) assert t.check(arr) + +@needs_numpy +def test37_noninteger_stride(): + a = np.array([[1, 2, 3, 4, 0, 0], [5, 6, 7, 8, 0, 0]], dtype=np.float32) + s = a[:, 0:4] # slice + t.pass_float32(s) + assert t.get_stride(s, 0) == 6; + assert t.get_stride(s, 1) == 1; + v = s.view(np.complex64) + t.pass_complex64(v) + assert t.get_stride(v, 0) == 3; + assert t.get_stride(v, 1) == 1; + + a = np.array([[1, 2, 3, 4, 0], [5, 6, 7, 8, 0]], dtype=np.float32) + s = a[:, 0:4] # slice + t.pass_float32(s) + assert t.get_stride(s, 0) == 5; + assert t.get_stride(s, 1) == 1; + v = s.view(np.complex64) + with pytest.raises(TypeError) as excinfo: + t.pass_complex64(v) + assert 'incompatible function arguments' in str(excinfo.value) + with pytest.raises(TypeError) as excinfo: + t.get_stride(v, 0); + assert 'incompatible function arguments' in str(excinfo.value) diff --git a/tests/test_ndarray_ext.pyi.ref b/tests/test_ndarray_ext.pyi.ref index 91ee07841..69a604ba7 100644 --- a/tests/test_ndarray_ext.pyi.ref +++ b/tests/test_ndarray_ext.pyi.ref @@ -67,6 +67,8 @@ def get_shape(array: Annotated[ArrayLike, dict(writable=False)]) -> list: ... def get_size(array: ArrayLike) -> int: ... +def get_stride(array: ArrayLike, i: int) -> int: ... + def implicit(array: Annotated[ArrayLike, dict(dtype='float32', order='C', shape=(2, 2))]) -> int: ... @overload