diff --git a/Project.toml b/Project.toml index ea6297b8..2354043c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ManifoldsBase" uuid = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb" authors = ["Seth Axen ", "Mateusz Baran ", "Ronny Bergmann ", "Antoine Levitt "] -version = "0.13.12" +version = "0.13.13" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/ManifoldsBase.jl b/src/ManifoldsBase.jl index f1534032..1ac56dc9 100644 --- a/src/ManifoldsBase.jl +++ b/src/ManifoldsBase.jl @@ -77,6 +77,17 @@ function allocate(::AbstractManifold, a, T::Type, dim1::Integer, dims::Integer.. end allocate(::AbstractManifold, a, T::Type, dims::Tuple) = allocate(a, T, dims) +""" + _pick_basic_allocation_argument(::AbstractManifold, f, x...) + +Pick which one of elements of `x` should be used as a basis for allocation in the +`allocate_result(M::AbstractManifold, f, x...)` method. This can be specialized to, for +example, skip `Identity` arguments in Manifolds.jl group-related functions. +""" +function _pick_basic_allocation_argument(::AbstractManifold, f, x...) + return x[1] +end + """ allocate_result(M::AbstractManifold, f, x...) @@ -88,7 +99,8 @@ isomorphisms. """ @inline function allocate_result(M::AbstractManifold, f, x...) T = allocate_result_type(M, f, x) - return allocate(M, x[1], T) + picked = _pick_basic_allocation_argument(M, f, x...) + return allocate(M, picked, T) end @inline function allocate_result(M::AbstractManifold, f) T = allocate_result_type(M, f, ()) diff --git a/src/PowerManifold.jl b/src/PowerManifold.jl index 5a1cd573..478ca72a 100644 --- a/src/PowerManifold.jl +++ b/src/PowerManifold.jl @@ -130,8 +130,11 @@ const PowerManifoldNestedReplacing = AbstractPowerManifold{ NestedReplacingPowerRepresentation, } where {𝔽} -_access_nested(x, i::Int) = x[i] -_access_nested(x, i::Tuple) = x[i...] +# _access_nested(::AbstractManifold, x, i::Tuple) can be overloaded to achieve +# manifold-specific nested element access (for example to `Identity` on power manifolds). +@inline _access_nested(M::AbstractManifold, x, i::Int) = _access_nested(M, x, (i,)) +@inline _access_nested(::AbstractManifold, x, i::Tuple) = _access_nested(x, i) +@inline _access_nested(x, i::Tuple) = x[i...] function Base.:^( M::PowerManifold{ @@ -150,7 +153,7 @@ function allocate_result(M::PowerManifoldNested, f, x...) return allocate(M, x[1]) else return [ - allocate_result(M.manifold, f, map(y -> _access_nested(y, i), x)...) for + allocate_result(M.manifold, f, map(y -> _access_nested(M, y, i), x)...) for i in get_iterator(M) ] end @@ -199,7 +202,9 @@ function allocate_result( ) end function allocate_result(M::PowerManifoldNested, f::typeof(get_vector), p, X) - return [allocate_result(M.manifold, f, _access_nested(p, i)) for i in get_iterator(M)] + return [ + allocate_result(M.manifold, f, _access_nested(M, p, i)) for i in get_iterator(M) + ] end function allocate_result(::PowerManifoldNestedReplacing, ::typeof(get_vector), p, X) return copy(p) @@ -452,7 +457,7 @@ function get_coordinates( M.manifold, _read(M, rep_size, p, i), _read(M, rep_size, X, i), - _access_nested(B.data.bases, i), + _access_nested(M, B.data.bases, i), ) for i in get_iterator(M) ] return reduce(vcat, reshape(vs, length(vs))) @@ -492,7 +497,7 @@ function get_coordinates!( view(c, v_iter:(v_iter + dim - 1)), _read(M, rep_size, p, i), _read(M, rep_size, X, i), - _access_nested(B.data.bases, i), + _access_nested(M, B.data.bases, i), ) v_iter += dim end @@ -532,7 +537,7 @@ function get_vector!( _write(M, rep_size, Y, i), _read(M, rep_size, p, i), c[v_iter:(v_iter + dim - 1)], - _access_nested(B.data.bases, i), + _access_nested(M, B.data.bases, i), ) v_iter += dim end @@ -553,7 +558,7 @@ function get_vector!( M.manifold, _read(M, rep_size, p, i), c[v_iter:(v_iter + dim - 1)], - _access_nested(B.data.bases, i), + _access_nested(M, B.data.bases, i), ) v_iter += dim end @@ -604,7 +609,7 @@ function _get_vectors( rep_size = representation_size(M.manifold) vs = typeof(zero_tv)[] for i in get_iterator(M) - b_i = _access_nested(B.data.bases, i) + b_i = _access_nested(M, B.data.bases, i) p_i = _read(M, rep_size, p, i) for v in b_i.data new_v = copy(M, p, zero_tv) @@ -622,7 +627,7 @@ function _get_vectors( zero_tv = zero_vector(M, p) vs = typeof(zero_tv)[] for i in get_iterator(M) - b_i = _access_nested(B.data.bases, i) + b_i = _access_nested(M, B.data.bases, i) for v in b_i.data new_v = copy(M, p, zero_tv) new_v[i...] = v