Skip to content

Commit

Permalink
Add reverse map iteration (#596)
Browse files Browse the repository at this point in the history
* Add reverse map iteration
  • Loading branch information
devsnek authored Mar 9, 2024
1 parent 63ca3fe commit f944588
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 25 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ versions.

### Added

- Map iterators are now [DoubleEndedIterators](https://doc.rust-lang.org/std/iter/trait.DoubleEndedIterator.html)
(#598), thus allowing being iterated in reverse using `.rev()`

### Fixed

### Changed
Expand Down
136 changes: 118 additions & 18 deletions rustler/src/types/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,40 +193,140 @@ impl<'a> Term<'a> {
}
}

pub struct MapIterator<'a> {
env: Env<'a>,
iter: map::ErlNifMapIterator,
struct SimpleMapIterator<'a> {
map: Term<'a>,
entry: map::MapIteratorEntry,
iter: Option<map::ErlNifMapIterator>,
last_key: Option<Term<'a>>,
done: bool,
}

impl<'a> MapIterator<'a> {
pub fn new(map: Term<'a>) -> Option<MapIterator<'a>> {
let env = map.get_env();
unsafe { map::map_iterator_create(env.as_c_arg(), map.as_c_arg()) }
.map(|iter| MapIterator { env, iter })
impl<'a> SimpleMapIterator<'a> {
fn next(&mut self) -> Option<(Term<'a>, Term<'a>)> {
if self.done {
return None;
}

let iter = loop {
match self.iter.as_mut() {
None => {
match unsafe {
map::map_iterator_create(
self.map.get_env().as_c_arg(),
self.map.as_c_arg(),
self.entry,
)
} {
Some(iter) => {
self.iter = Some(iter);
continue;
}
None => {
self.done = true;
return None;
}
}
}
Some(iter) => {
break iter;
}
}
};

let env = self.map.get_env();

unsafe {
match map::map_iterator_get_pair(env.as_c_arg(), iter) {
Some((key, value)) => {
match self.entry {
map::MapIteratorEntry::First => {
map::map_iterator_next(env.as_c_arg(), iter);
}
map::MapIteratorEntry::Last => {
map::map_iterator_prev(env.as_c_arg(), iter);
}
}
let key = Term::new(env, key);
self.last_key = Some(key);
Some((key, Term::new(env, value)))
}
None => {
self.done = true;
None
}
}
}
}
}

impl<'a> Drop for MapIterator<'a> {
impl<'a> Drop for SimpleMapIterator<'a> {
fn drop(&mut self) {
unsafe {
map::map_iterator_destroy(self.env.as_c_arg(), &mut self.iter);
if let Some(iter) = self.iter.as_mut() {
unsafe {
map::map_iterator_destroy(self.map.get_env().as_c_arg(), iter);
}
}
}
}

impl<'a> Iterator for MapIterator<'a> {
type Item = (Term<'a>, Term<'a>);
pub struct MapIterator<'a> {
forward: SimpleMapIterator<'a>,
reverse: SimpleMapIterator<'a>,
}

fn next(&mut self) -> Option<(Term<'a>, Term<'a>)> {
unsafe {
map::map_iterator_get_pair(self.env.as_c_arg(), &mut self.iter).map(|(key, value)| {
map::map_iterator_next(self.env.as_c_arg(), &mut self.iter);
(Term::new(self.env, key), Term::new(self.env, value))
impl<'a> MapIterator<'a> {
pub fn new(map: Term<'a>) -> Option<MapIterator<'a>> {
if map.is_map() {
Some(MapIterator {
forward: SimpleMapIterator {
map,
entry: map::MapIteratorEntry::First,
iter: None,
last_key: None,
done: false,
},
reverse: SimpleMapIterator {
map,
entry: map::MapIteratorEntry::Last,
iter: None,
last_key: None,
done: false,
},
})
} else {
None
}
}
}

impl<'a> Iterator for MapIterator<'a> {
type Item = (Term<'a>, Term<'a>);

fn next(&mut self) -> Option<Self::Item> {
self.forward.next().and_then(|(key, value)| {
if self.reverse.last_key == Some(key) {
self.forward.done = true;
self.reverse.done = true;
return None;
}
Some((key, value))
})
}
}

impl<'a> DoubleEndedIterator for MapIterator<'a> {
fn next_back(&mut self) -> Option<Self::Item> {
self.reverse.next().and_then(|(key, value)| {
if self.forward.last_key == Some(key) {
self.forward.done = true;
self.reverse.done = true;
return None;
}
Some((key, value))
})
}
}

impl<'a> Decoder<'a> for MapIterator<'a> {
fn decode(term: Term<'a>) -> NifResult<Self> {
match MapIterator::new(term) {
Expand Down
21 changes: 19 additions & 2 deletions rustler/src/wrapper/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,26 @@ pub unsafe fn map_update(
Some(result.assume_init())
}

pub unsafe fn map_iterator_create(env: NIF_ENV, map: NIF_TERM) -> Option<ErlNifMapIterator> {
#[derive(Clone, Copy, Debug)]
pub enum MapIteratorEntry {
First,
Last,
}

pub unsafe fn map_iterator_create(
env: NIF_ENV,
map: NIF_TERM,
entry: MapIteratorEntry,
) -> Option<ErlNifMapIterator> {
let mut iter = MaybeUninit::uninit();
let success = rustler_sys::enif_map_iterator_create(
env,
map,
iter.as_mut_ptr(),
ErlNifMapIteratorEntry::ERL_NIF_MAP_ITERATOR_HEAD,
match entry {
MapIteratorEntry::First => ErlNifMapIteratorEntry::ERL_NIF_MAP_ITERATOR_HEAD,
MapIteratorEntry::Last => ErlNifMapIteratorEntry::ERL_NIF_MAP_ITERATOR_TAIL,
},
);
if success == 0 {
None
Expand Down Expand Up @@ -103,6 +116,10 @@ pub unsafe fn map_iterator_next(env: NIF_ENV, iter: &mut ErlNifMapIterator) {
rustler_sys::enif_map_iterator_next(env, iter);
}

pub unsafe fn map_iterator_prev(env: NIF_ENV, iter: &mut ErlNifMapIterator) {
rustler_sys::enif_map_iterator_prev(env, iter);
}

pub unsafe fn make_map_from_arrays(
env: NIF_ENV,
keys: &[NIF_TERM],
Expand Down
3 changes: 2 additions & 1 deletion rustler_tests/lib/rustler_test.ex
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ defmodule RustlerTest do
def term_type(_term), do: err()

def sum_map_values(_), do: err()
def map_entries_sorted(_), do: err()
def map_entries(_), do: err()
def map_entries_reversed(_), do: err()
def map_from_arrays(_keys, _values), do: err()
def map_from_pairs(_pairs), do: err()
def map_generic(_), do: err()
Expand Down
3 changes: 2 additions & 1 deletion rustler_tests/native/rustler_test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ rustler::init!(
test_term::term_phash2_hash,
test_term::term_type,
test_map::sum_map_values,
test_map::map_entries_sorted,
test_map::map_entries,
test_map::map_entries_reversed,
test_map::map_from_arrays,
test_map::map_from_pairs,
test_map::map_generic,
Expand Down
18 changes: 16 additions & 2 deletions rustler_tests/native/rustler_test/src/test_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,28 @@ pub fn sum_map_values(iter: MapIterator) -> NifResult<i64> {
}

#[rustler::nif]
pub fn map_entries_sorted<'a>(env: Env<'a>, iter: MapIterator<'a>) -> NifResult<Vec<Term<'a>>> {
pub fn map_entries<'a>(env: Env<'a>, iter: MapIterator<'a>) -> NifResult<Vec<Term<'a>>> {
let mut vec = vec![];
for (key, value) in iter {
let key_string = key.decode::<String>()?;
vec.push((key_string, value));
}

vec.sort_by_key(|pair| pair.0.clone());
let erlang_pairs: Vec<Term> = vec
.into_iter()
.map(|(key, value)| make_tuple(env, &[key.encode(env), value]))
.collect();
Ok(erlang_pairs)
}

#[rustler::nif]
pub fn map_entries_reversed<'a>(env: Env<'a>, iter: MapIterator<'a>) -> NifResult<Vec<Term<'a>>> {
let mut vec = vec![];
for (key, value) in iter.rev() {
let key_string = key.decode::<String>()?;
vec.push((key_string, value));
}

let erlang_pairs: Vec<Term> = vec
.into_iter()
.map(|(key, value)| make_tuple(env, &[key.encode(env), value]))
Expand Down
7 changes: 6 additions & 1 deletion rustler_tests/test/map_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@ defmodule RustlerTest.MapTest do
end

test "map iteration with keys" do
entries = RustlerTest.map_entries(%{"d" => 0, "a" => 1, "b" => 7, "e" => 4, "c" => 6})

assert [{"a", 1}, {"b", 7}, {"c", 6}, {"d", 0}, {"e", 4}] ==
RustlerTest.map_entries_sorted(%{"d" => 0, "a" => 1, "b" => 7, "e" => 4, "c" => 6})
Enum.sort_by(entries, &elem(&1, 0))

assert Enum.reverse(entries) ==
RustlerTest.map_entries_reversed(%{"d" => 0, "a" => 1, "b" => 7, "e" => 4, "c" => 6})
end

test "map from arrays" do
Expand Down

0 comments on commit f944588

Please sign in to comment.