Skip to content

Commit

Permalink
Make mapv_into_any() work for ArcArray, resolves rust-ndarray#1280
Browse files Browse the repository at this point in the history
  • Loading branch information
benkay86 committed May 4, 2023
1 parent 0740695 commit e8a73e6
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 12 deletions.
13 changes: 13 additions & 0 deletions src/data_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -766,3 +766,16 @@ impl<'a, A: 'a, B: 'a> RawDataSubst<B> for CowRepr<'a, A> {
}
}
}

/// Plug the data element type of one owned array representation into another.
///
/// For example, `<OwnedRepr<f64> as Plug<f32>>::Type` has type `OwnedRepr<f32>`.
pub trait Plug<A> {
type Type;
}
impl <A, B> Plug<A> for crate::OwnedRepr<B> {
type Type = crate::OwnedRepr<A>;
}
impl <A, B> Plug<A> for crate::OwnedArcRepr<B> {
type Type = crate::OwnedArcRepr<A>;
}
40 changes: 30 additions & 10 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::imp_prelude::*;

use crate::{arraytraits, DimMax};
use crate::argument_traits::AssignElem;
use crate::data_traits::Plug;
use crate::dimension;
use crate::dimension::IntoDimension;
use crate::dimension::{
Expand Down Expand Up @@ -2586,15 +2587,29 @@ where
/// map is performed as in [`mapv`].
///
/// Elements are visited in arbitrary order.
///
///
/// Note that the compiler will need some hint about the return type, which
/// is generic over [`DataOwned`], and can thus be an [`Array`] or
/// [`ArcArray`]. Example:
///
/// ```rust
/// # use ndarray::{array, Array};
/// let a = array![[1., 2., 3.]];
/// let a_plus_one: Array<_, _> = a.mapv_into_any(|a| a + 1.);
/// ```
///
/// [`mapv_into`]: ArrayBase::mapv_into
/// [`mapv`]: ArrayBase::mapv
pub fn mapv_into_any<B, F>(self, mut f: F) -> Array<B, D>
pub fn mapv_into_any<B, F, T/* , TT*/>(self, mut f: F) -> ArrayBase<T, D>
where
S: DataMut,
S: DataMut<Elem = A>,
F: FnMut(A) -> B,
A: Clone + 'static,
B: 'static,
T: DataOwned<Elem = B> + Plug<A>, // lets us introspect on the types of array representations containing different data elements
<T as Plug<A>>::Type: RawData, // required by mapv_into()
ArrayBase<<T as Plug<A>>::Type, D>: From<ArrayBase<S, D>>, // required by into() to convert from the DataMut array representation of S to the DataOwned array representation of T
ArrayBase<T, D>: From<Array<B, D>>, // required by mapv()
{
if core::any::TypeId::of::<A>() == core::any::TypeId::of::<B>() {
// A and B are the same type.
Expand All @@ -2604,16 +2619,21 @@ where
// Safe because A and B are the same type.
unsafe { unlimited_transmute::<B, A>(b) }
};
// Delegate to mapv_into() using the wrapped closure.
// Convert output to a uniquely owned array of type Array<A, D>.
let output = self.mapv_into(f).into_owned();
// Change the return type from Array<A, D> to Array<B, D>.
// Again, safe because A and B are the same type.
unsafe { unlimited_transmute::<Array<A, D>, Array<B, D>>(output) }
// Delegate to mapv_into() to map from element type A to type A.
let output = self.mapv_into(f);
// Convert from S's array representation to T's array representation.
// Suppose `T is `OwnedRepr<B>`.
// Then `<T as Plug<A>>::Type` is `OwnedRepr<A>`.
let output: ArrayBase<<T as Plug<A>>::Type, D> = output.into();
// Map from T's array representation with data element type A to T's
// array representation with element type B.
// This is safe because A and B are the same type and the array
// representations are also the same higher kinded type.
unsafe { unlimited_transmute::<ArrayBase<<T as Plug<A>>::Type, D>, ArrayBase<T,D>>(output) }
} else {
// A and B are not the same type.
// Fallback to mapv().
self.mapv(f)
self.mapv(f).into()
}
}

Expand Down
30 changes: 28 additions & 2 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -995,14 +995,40 @@ fn map1() {
fn mapv_into_any_same_type() {
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
assert_eq!(a.mapv_into_any(|a| a + 1.), a_plus_one);
let b: Array<_, _> = a.mapv_into_any(|a| a + 1.);
assert_eq!(b, a_plus_one);
}

#[test]
fn mapv_into_any_diff_types() {
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
assert_eq!(a.mapv_into_any(|a| a.round() as i32 % 2 == 0), a_even);
let b: Array<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
assert_eq!(b, a_even);
}

#[test]
fn mapv_into_any_arcarray_same_type() {
let a: ArcArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into_shared();
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
let b: ArcArray<_, _> = a.mapv_into_any(|a| a + 1.);
assert_eq!(b, a_plus_one);
}

#[test]
fn mapv_into_any_arcarray_diff_types() {
let a: ArcArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into_shared();
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
let b: ArcArray<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
assert_eq!(b, a_even);
}

#[test]
fn mapv_into_any_diff_outer_types() {
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
let b: ArcArray<_, _> = a.mapv_into_any(|a| a + 1.);
assert_eq!(b, a_plus_one);
}

#[test]
Expand Down

0 comments on commit e8a73e6

Please sign in to comment.