diff --git a/src/data_traits.rs b/src/data_traits.rs index acf4b0b7a..f44e8ca75 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -766,3 +766,16 @@ impl<'a, A: 'a, B: 'a> RawDataSubst for CowRepr<'a, A> { } } } + +/// Plug the data element type of one owned array representation into another. +/// +/// For example, ` as Plug>::Type` has type `OwnedRepr`. +pub trait Plug { + type Type; +} +impl Plug for crate::OwnedRepr { + type Type = crate::OwnedRepr; +} +impl Plug for crate::OwnedArcRepr { + type Type = crate::OwnedArcRepr; +} diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 115cd2d71..e672635b6 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -16,6 +16,7 @@ use crate::imp_prelude::*; use crate::{arraytraits, DimMax}; use crate::argument_traits::AssignElem; +use crate::data_traits::Plug; use crate::dimension; use crate::dimension::IntoDimension; use crate::dimension::{ @@ -2586,15 +2587,29 @@ where /// map is performed as in [`mapv`]. /// /// Elements are visited in arbitrary order. - /// + /// + /// Note that the compiler will need some hint about the return type, which + /// is generic over [`DataOwned`], and can thus be an [`Array`] or + /// [`ArcArray`]. Example: + /// + /// ```rust + /// # use ndarray::{array, Array}; + /// let a = array![[1., 2., 3.]]; + /// let a_plus_one: Array<_, _> = a.mapv_into_any(|a| a + 1.); + /// ``` + /// /// [`mapv_into`]: ArrayBase::mapv_into /// [`mapv`]: ArrayBase::mapv - pub fn mapv_into_any(self, mut f: F) -> Array + pub fn mapv_into_any(self, mut f: F) -> ArrayBase where - S: DataMut, + S: DataMut, F: FnMut(A) -> B, A: Clone + 'static, B: 'static, + T: DataOwned + Plug, // lets us introspect on the types of array representations containing different data elements + >::Type: RawData, // required by mapv_into() + ArrayBase<>::Type, D>: From>, // required by into() to convert from the DataMut array representation of S to the DataOwned array representation of T + ArrayBase: From>, // required by mapv() { if core::any::TypeId::of::() == core::any::TypeId::of::() { // A and B are the same type. @@ -2604,16 +2619,21 @@ where // Safe because A and B are the same type. unsafe { unlimited_transmute::(b) } }; - // Delegate to mapv_into() using the wrapped closure. - // Convert output to a uniquely owned array of type Array. - let output = self.mapv_into(f).into_owned(); - // Change the return type from Array to Array. - // Again, safe because A and B are the same type. - unsafe { unlimited_transmute::, Array>(output) } + // Delegate to mapv_into() to map from element type A to type A. + let output = self.mapv_into(f); + // Convert from S's array representation to T's array representation. + // Suppose `T is `OwnedRepr`. + // Then `>::Type` is `OwnedRepr`. + let output: ArrayBase<>::Type, D> = output.into(); + // Map from T's array representation with data element type A to T's + // array representation with element type B. + // This is safe because A and B are the same type and the array + // representations are also the same higher kinded type. + unsafe { unlimited_transmute::>::Type, D>, ArrayBase>(output) } } else { // A and B are not the same type. // Fallback to mapv(). - self.mapv(f) + self.mapv(f).into() } } diff --git a/tests/array.rs b/tests/array.rs index e3922ea8d..20229c331 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -995,14 +995,40 @@ fn map1() { fn mapv_into_any_same_type() { let a: Array = array![[1., 2., 3.], [4., 5., 6.]]; let a_plus_one: Array = array![[2., 3., 4.], [5., 6., 7.]]; - assert_eq!(a.mapv_into_any(|a| a + 1.), a_plus_one); + let b: Array<_, _> = a.mapv_into_any(|a| a + 1.); + assert_eq!(b, a_plus_one); } #[test] fn mapv_into_any_diff_types() { let a: Array = array![[1., 2., 3.], [4., 5., 6.]]; let a_even: Array = array![[false, true, false], [true, false, true]]; - assert_eq!(a.mapv_into_any(|a| a.round() as i32 % 2 == 0), a_even); + let b: Array<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0); + assert_eq!(b, a_even); +} + +#[test] +fn mapv_into_any_arcarray_same_type() { + let a: ArcArray = array![[1., 2., 3.], [4., 5., 6.]].into_shared(); + let a_plus_one: Array = array![[2., 3., 4.], [5., 6., 7.]]; + let b: ArcArray<_, _> = a.mapv_into_any(|a| a + 1.); + assert_eq!(b, a_plus_one); +} + +#[test] +fn mapv_into_any_arcarray_diff_types() { + let a: ArcArray = array![[1., 2., 3.], [4., 5., 6.]].into_shared(); + let a_even: Array = array![[false, true, false], [true, false, true]]; + let b: ArcArray<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0); + assert_eq!(b, a_even); +} + +#[test] +fn mapv_into_any_diff_outer_types() { + let a: Array = array![[1., 2., 3.], [4., 5., 6.]]; + let a_plus_one: Array = array![[2., 3., 4.], [5., 6., 7.]]; + let b: ArcArray<_, _> = a.mapv_into_any(|a| a + 1.); + assert_eq!(b, a_plus_one); } #[test]