From a1fa94ac90dedca2b8fb27c162d394aada15ac40 Mon Sep 17 00:00:00 2001 From: Nicolas Cornu Date: Tue, 8 Oct 2024 15:56:58 +0200 Subject: [PATCH] Use nb::iterator instead of Py_Iterator --- src/nrnpython/nrnpy_p2h.cpp | 16 +++++++--------- src/utils/enumerate.h | 4 ++-- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/nrnpython/nrnpy_p2h.cpp b/src/nrnpython/nrnpy_p2h.cpp index f84cb29bcf..406cfa835e 100644 --- a/src/nrnpython/nrnpy_p2h.cpp +++ b/src/nrnpython/nrnpy_p2h.cpp @@ -14,6 +14,8 @@ #include "parse.hpp" #include + +#include "utils/enumerate.h" namespace nb = nanobind; static char* nrnpyerr_str(); @@ -937,15 +939,13 @@ static Object* py_alltoall_type(int size, int type) { // for alltoall, each rank handled identically // for scatter, root handled as list all, other ranks handled as None if (type == 1 || nrnmpi_myid == root) { // psrc is list of nhost items + nb::list psrc_list(psrc); scnt = new int[np]; for (int i = 0; i < np; ++i) { scnt[i] = 0; } - PyObject* iterator = PyObject_GetIter(psrc); - PyObject* p; - size_t bufsz = 100000; // 100k buffer to start with if (size > 0) { // or else the positive number specified bufsz = size; @@ -954,13 +954,13 @@ static Object* py_alltoall_type(int size, int type) { s = new char[bufsz]; } size_t curpos = 0; - for (size_t i = 0; (p = PyIter_Next(iterator)) != NULL; ++i) { - if (p == Py_None) { + for (auto&& [i, p]: enumerate(psrc_list)) { + if (p.is_none()) { scnt[i] = 0; - Py_DECREF(p); + p.dec_ref(); continue; } - auto b = pickle(p); + auto b = pickle(p.ptr()); if (size >= 0) { if (curpos + b.size() >= bufsz) { bufsz = bufsz * 2 + b.size(); @@ -977,9 +977,7 @@ static Object* py_alltoall_type(int size, int type) { } curpos += b.size(); scnt[i] = static_cast(b.size()); - Py_DECREF(p); } - Py_DECREF(iterator); // scatter equivalent to alltoall NONE list for not root ranks. } else if (type == 5 && nrnmpi_myid != root) { diff --git a/src/utils/enumerate.h b/src/utils/enumerate.h index 1a0b38f2ad..e9492887bc 100644 --- a/src/utils/enumerate.h +++ b/src/utils/enumerate.h @@ -99,7 +99,7 @@ constexpr auto enumerate(T&& iterable) { ++iter; } auto operator*() const { - return std::tie(i, *iter); + return std::forward_as_tuple(i, *iter); } }; struct iterable_wrapper { @@ -129,7 +129,7 @@ constexpr auto renumerate(T&& iterable) { ++iter; } auto operator*() const { - return std::tie(i, *iter); + return std::forward_as_tuple(i, *iter); } }; struct iterable_wrapper {