diff --git a/Project.toml b/Project.toml index 8a5e5489..5254b6ca 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ManifoldsBase" uuid = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb" authors = ["Seth Axen ", "Mateusz Baran ", "Ronny Bergmann ", "Antoine Levitt "] -version = "0.12.9" +version = "0.12.10" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/ManifoldsBase.jl b/src/ManifoldsBase.jl index e474d0af..e89e177e 100644 --- a/src/ManifoldsBase.jl +++ b/src/ManifoldsBase.jl @@ -15,7 +15,8 @@ import Base: show, +, -, - * + *, + == import LinearAlgebra: dot, norm, det, cross, I, UniformScaling, Diagonal import Markdown: @doc_str @@ -62,6 +63,7 @@ allocate(a::AbstractArray{<:AbstractArray}, T::Type) = map(t -> allocate(t, T), allocate(a::NTuple{N,AbstractArray} where {N}) = map(allocate, a) allocate(a::NTuple{N,AbstractArray} where {N}, T::Type) = map(t -> allocate(t, T), a) + """ allocate_result(M::AbstractManifold, f, x...) @@ -100,7 +102,6 @@ Compute the angle between tangent vectors `X` and `Y` at point `p` from the function angle(M::AbstractManifold, p, X, Y) return acos(real(inner(M, p, X, Y)) / norm(M, p, X) / norm(M, p, Y)) end - """ base_manifold(M::AbstractManifold, depth = Val(-1)) @@ -531,7 +532,7 @@ The size of an array representing a point on [`AbstractManifold`](@ref) `M`. Returns `nothing` by default indicating that points are not represented using an `AbstractArray`. """ -function representation_size(M::AbstractManifold) +function representation_size(::AbstractManifold) return nothing end @@ -569,6 +570,7 @@ include("vector_transport.jl") include("DecoratorManifold.jl") include("bases.jl") include("vector_spaces.jl") +include("point_vector_fallbacks.jl") include("ValidationManifold.jl") include("EmbeddedManifold.jl") include("DefaultManifold.jl") diff --git a/src/ValidationManifold.jl b/src/ValidationManifold.jl index 5ab0fae4..3355f65d 100644 --- a/src/ValidationManifold.jl +++ b/src/ValidationManifold.jl @@ -57,26 +57,9 @@ This distinguished the value from [`ValidationMPoint`](@ref)s vectors of other t """ const ValidationCoTVector = ValidationFibreVector{CotangentSpaceType} +@eval @manifold_vector_forwards ValidationFibreVector{TType} TType value -function (+)(X::ValidationFibreVector{TType}, Y::ValidationFibreVector{TType}) where {TType} - return ValidationFibreVector{TType}(X.value + Y.value) -end -function (-)(X::ValidationFibreVector{TType}, Y::ValidationFibreVector{TType}) where {TType} - return ValidationFibreVector{TType}(X.value - Y.value) -end -(-)(X::ValidationFibreVector{TType}) where {TType} = ValidationFibreVector{TType}(-X.value) -function (*)(a::Number, X::ValidationFibreVector{TType}) where {TType} - return ValidationFibreVector{TType}(a * X.value) -end - -allocate(p::ValidationMPoint) = ValidationMPoint(allocate(p.value)) -allocate(p::ValidationMPoint, ::Type{T}) where {T} = ValidationMPoint(allocate(p.value, T)) -function allocate(X::ValidationFibreVector{TType}) where {TType} - return ValidationFibreVector{TType}(allocate(X.value)) -end -function allocate(X::ValidationFibreVector{TType}, ::Type{T}) where {TType,T} - return ValidationFibreVector{TType}(allocate(X.value, T)) -end +@eval @manifold_element_forwards ValidationMPoint value """ array_value(p) @@ -422,9 +405,7 @@ function mid_point!(M::ValidationManifold, q, p1, p2; kwargs...) end number_eltype(::Type{ValidationMPoint{V}}) where {V} = number_eltype(V) -number_eltype(p::ValidationMPoint) = number_eltype(p.value) number_eltype(::Type{ValidationFibreVector{TType,V}}) where {TType,V} = number_eltype(V) -number_eltype(X::ValidationFibreVector) = number_eltype(X.value) function project!(M::ValidationManifold, Y, p, X; kwargs...) is_point(M, p, true; kwargs...) @@ -433,15 +414,6 @@ function project!(M::ValidationManifold, Y, p, X; kwargs...) return Y end -similar(p::ValidationMPoint) = ValidationMPoint(similar(p.value)) -similar(p::ValidationMPoint, ::Type{T}) where {T} = ValidationMPoint(similar(p.value, T)) -function similar(X::ValidationFibreVector{TType}) where {TType} - return ValidationFibreVector{TType}(similar(X.value)) -end -function similar(X::ValidationFibreVector{TType}, ::Type{T}) where {TType,T} - return ValidationFibreVector{TType}(similar(X.value, T)) -end - function vector_transport_along!( M::ValidationManifold, Y, diff --git a/src/bases.jl b/src/bases.jl index a45d3828..b49a2051 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -285,6 +285,13 @@ const DISAMBIGUATION_COTANGENT_BASIS_TYPES = [ DefaultOrthogonalBasis{<:Any,CotangentSpaceType}, ] +""" + allocate_coordinates(M::AbstractManifold, p, T, n::Int) + +Allocate vector of coordinates of length `n` of type `T` of a vector at point `p` +on manifold `M`. +""" +allocate_coordinates(M::AbstractManifold, p, T, n::Int) = allocate(p, T, n) function allocate_result( M::AbstractManifold, @@ -294,7 +301,7 @@ function allocate_result( B::AbstractBasis, ) T = allocate_result_type(M, f, (p, X)) - return allocate(p, T, number_of_coordinates(M, B)) + return allocate_coordinates(M, p, T, number_of_coordinates(M, B)) end function allocate_result( @@ -305,7 +312,7 @@ function allocate_result( B::CachedBasis, ) T = allocate_result_type(M, f, (p, X)) - return allocate(p, T, number_of_coordinates(M, B)) + return allocate_coordinates(M, p, T, number_of_coordinates(M, B)) end @inline function allocate_result_type( diff --git a/src/maintypes.jl b/src/maintypes.jl index 90ce80aa..2e37a56b 100644 --- a/src/maintypes.jl +++ b/src/maintypes.jl @@ -23,7 +23,13 @@ abstract type AbstractManifold{𝔽} end Type for a point on a manifold. While a [`AbstractManifold`](@ref) does not necessarily require this type, for example when it is implemented for `Vector`s or `Matrix` type elements, this type -can be used for more complicated representations, semantic verification, or even dispatch -for different representations of points on a manifold. +can be used either +* for more complicated representations, +* semantic verification, or +* even dispatch for different representations of points on a manifold. + +Since semantic verification and different representations usually might still only store a +matrix internally, it is possible to use [`@manifold_element_forwards`](@ref) and +[`@default_manifold_fallbacks`](@ref) to reduce implementation overhead. """ abstract type AbstractManifoldPoint end diff --git a/src/point_vector_fallbacks.jl b/src/point_vector_fallbacks.jl new file mode 100644 index 00000000..111d1434 --- /dev/null +++ b/src/point_vector_fallbacks.jl @@ -0,0 +1,325 @@ + +""" + manifold_element_forwards(T, field::Symbol) + manifold_element_forwards(T, Twhere, field::Symbol) + +Introduce basic fallbacks for type `T` (which can be a subtype of `Twhere`) that represents +points or vectors for a manifold. +Fallbacks will work by forwarding to the field passed in `field`` + +List of forwarded functions: +* [`allocate`](@ref), +* [`copy`](@ref), +* [`copyto!`](@ref), +* [`number_eltype`](@ref) (only for values, not the type itself), +* `similar`, +* `==`. +""" +macro manifold_element_forwards(T, field::Symbol) + return esc(quote + @manifold_element_forwards ($T) _ ($field) + end) +end +macro manifold_element_forwards(T, Twhere, field::Symbol) + return esc( + quote + ManifoldsBase.allocate(p::$T) where {$Twhere} = $T(allocate(p.$field)) + function ManifoldsBase.allocate(p::$T, ::Type{P}) where {P,$Twhere} + return $T(allocate(p.$field, P)) + end + + @inline Base.copy(p::$T) where {$Twhere} = $T(copy(p.$field)) + + function Base.copyto!(q::$T, p::$T) where {$Twhere} + copyto!(q.$field, p.$field) + return q + end + + function ManifoldsBase.number_eltype(p::$T) where {$Twhere} + return typeof(one(eltype(p.$field))) + end + + Base.similar(p::$T) where {$Twhere} = $T(similar(p.$field)) + Base.similar(p::$T, ::Type{P}) where {P,$Twhere} = $T(similar(p.$field, P)) + + Base.:(==)(p::$T, q::$T) where {$Twhere} = (p.$field == q.$field) + end, + ) +end + + +""" + default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol) + +Introduce default fallbacks for all basic functions on manifolds, for manifold of type `TM`, +points of type `TP`, tangent vectors of type `TV`, with forwarding to fields `pfield` and +`vfield` for point and tangent vector functions, respectively. +""" +macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol) + block = quote + function ManifoldsBase.allocate_coordinates(M::$TM, p::$TP, T, n::Int) + return ManifoldsBase.allocate_coordinates(M, p.$pfield, T, n) + end + + function ManifoldsBase.angle(M::$TM, p::$TP, X::$TV, Y::$TV) + return angle(M, p.$pfield, X.$vfield, Y.$vfield) + end + + function ManifoldsBase.check_point(M::$TM, p::$TP; kwargs...) + return check_point(M, p.$pfield; kwargs...) + end + + function ManifoldsBase.check_vector(M::$TM, p::$TP, X::$TV; kwargs...) + return check_vector(M, p.$pfield, X.$vfield; kwargs...) + end + + function ManifoldsBase.distance(M::$TM, p::$TP, q::$TP) + return distance(M, p.$pfield, q.$pfield) + end + + function ManifoldsBase.embed!(M::$TM, q::$TP, p::$TP) + return embed!(M, q.$pfield, p.$pfield) + end + + function ManifoldsBase.embed!(M::$TM, Y::$TV, p::$TP, X::$TV) + return embed!(M, Y.$vfield, p.$pfield, X.$vfield) + end + + function ManifoldsBase.exp!(M::$TM, q::$TP, p::$TP, X::$TV) + exp!(M, q.$pfield, p.$pfield, X.$vfield) + return q + end + + function ManifoldsBase.inner(M::$TM, p::$TP, X::$TV, Y::$TV) + return inner(M, p.$pfield, X.$vfield, Y.$vfield) + end + + function ManifoldsBase.inverse_retract!( + M::$TM, + X::$TV, + p::$TP, + q::$TP, + m::LogarithmicInverseRetraction, + ) + inverse_retract!(M, X.$vfield, p.$pfield, q.$pfield, m) + return X + end + + function ManifoldsBase.isapprox(M::$TM, p::$TP, q::$TP; kwargs...) + return isapprox(M, p.$pfield, q.$pfield; kwargs...) + end + + function ManifoldsBase.isapprox(M::$TM, p::$TP, X::$TV, Y::$TV; kwargs...) + return isapprox(M, p.$pfield, X.$vfield, Y.$vfield; kwargs...) + end + + function ManifoldsBase.allocate_result(::$TM, ::typeof(log), p::$TP, ::$TP) + a = allocate(p.$vfield) + return $TV(a) + end + function ManifoldsBase.allocate_result( + ::$TM, + ::typeof(inverse_retract), + p::$TP, + ::$TP, + ) + a = allocate(p.$vfield) + return $TV(a) + end + + function ManifoldsBase.log!(M::$TM, X::$TV, p::$TP, q::$TP) + log!(M, X.$vfield, p.$pfield, q.$pfield) + return X + end + + function ManifoldsBase.norm(M::$TM, p::$TP, X::$TV) + return norm(M, p.$pfield, X.$vfield) + end + + function ManifoldsBase.retract!( + M::$TM, + q::$TP, + p::$TP, + X::$TV, + m::ExponentialRetraction, + ) + retract!(M, q.$pfield, p.$pfield, X.$vfield, m) + return X + end + + function ManifoldsBase.vector_transport_along!(M::$TM, Y::$TV, p::$TP, X::$TV, c) + vector_transport_along!(M, Y.$vfield, p.$pfield, X.$vfield, c) + return Y + end + + function ManifoldsBase.zero_vector(M::$TM, p::$TP) + return $TV(zero_vector(M, p.$pfield)) + end + + function ManifoldsBase.zero_vector!(M::$TM, X::$TV, p::$TP) + zero_vector!(M, X.$vfield, p.$pfield) + return X + end + end + + for BT in [ + ManifoldsBase.DISAMBIGUATION_BASIS_TYPES..., + ManifoldsBase.DISAMBIGUATION_COTANGENT_BASIS_TYPES..., + ] + push!( + block.args, + quote + function ManifoldsBase.get_coordinates!(M::$TM, Y, p::$TP, X::$TV, B::$BT) + return get_coordinates!(M, Y, p.$pfield, X.$vfield, B) + end + + function ManifoldsBase.get_vector(M::$TM, p::$TP, X, B::$BT) + return $TV(get_vector(M, p.$pfield, X, B)) + end + + function ManifoldsBase.get_vector!(M::$TM, Y::$TV, p::$TP, X, B::$BT) + return get_vector!(M, Y.$vfield, p.$pfield, X, B) + end + end, + ) + end + + for VTM in [ParallelTransport, VECTOR_TRANSPORT_DISAMBIGUATION...] + push!( + block.args, + quote + function ManifoldsBase.vector_transport_direction!( + M::$TM, + Y::$TV, + p::$TP, + X::$TV, + d::$TV, + m::$VTM, + ) + vector_transport_direction!( + M, + Y.$vfield, + p.$pfield, + X.$vfield, + d.$vfield, + m, + ) + return Y + end + function ManifoldsBase.vector_transport_to!( + M::$TM, + Y::$TV, + p::$TP, + X::$TV, + q::$TP, + m::$VTM, + ) + vector_transport_to!(M, Y.$vfield, p.$pfield, X.$vfield, q.$pfield, m) + return Y + end + end, + ) + end + + return esc(block) +end + + + +@doc raw""" + manifold_vector_forwards(T, field::Symbol) + manifold_vector_forwards(T, Twhere, field::Symbol) + +Introduce basic fallbacks for type `T` that represents vectors from a vector bundle for a +manifold. `Twhere` is put into `where` clause of each method. Fallbacks work by forwarding +to field passed as `field`. + +List of forwarded functions: +* basic arithmetic (`*`, `/`, `\`, `+`, `-`), +* all things from [`@manifold_element_forwards`](@ref), +* broadcasting support. + +# example + + @eval @manifold_vector_forwards ValidationFibreVector{TType} TType value +""" +macro manifold_vector_forwards(T, field::Symbol) + return esc(quote + @manifold_vector_forwards ($T) _ ($field) + end) +end +macro manifold_vector_forwards(T, Twhere, field::Symbol) + return esc( + quote + Base.:*(X::$T, s::Number) where {$Twhere} = $T(X.$field * s) + Base.:*(s::Number, X::$T) where {$Twhere} = $T(s * X.$field) + Base.:/(X::$T, s::Number) where {$Twhere} = $T(X.$field / s) + Base.:\(s::Number, X::$T) where {$Twhere} = $T(s \ X.$field) + Base.:+(X::$T, Y::$T) where {$Twhere} = $T(X.$field + Y.$field) + Base.:-(X::$T, Y::$T) where {$Twhere} = $T(X.$field - Y.$field) + Base.:-(X::$T) where {$Twhere} = $T(-X.$field) + Base.:+(X::$T) where {$Twhere} = $T(X.$field) + Base.zero(X::$T) where {$Twhere} = $T(zero(X.$field)) + + @eval @manifold_element_forwards $T $Twhere $field + + Base.axes(p::$T) where {$Twhere} = axes(p.$field) + + function Broadcast.BroadcastStyle(::Type{<:$T}) where {$Twhere} + return Broadcast.Style{$T}() + end + function Broadcast.BroadcastStyle( + ::Broadcast.AbstractArrayStyle{0}, + b::Broadcast.Style{$T}, + ) where {$Twhere} + return b + end + + function Broadcast.instantiate( + bc::Broadcast.Broadcasted{Broadcast.Style{$T},Nothing}, + ) where {$Twhere} + return bc + end + function Broadcast.instantiate( + bc::Broadcast.Broadcasted{Broadcast.Style{$T}}, + ) where {$Twhere} + Broadcast.check_broadcast_axes(bc.axes, bc.args...) + return bc + end + + Broadcast.broadcastable(X::$T) where {$Twhere} = X + + @inline function Base.copy( + bc::Broadcast.Broadcasted{Broadcast.Style{$T}}, + ) where {$Twhere} + return $T(Broadcast._broadcast_getindex(bc, 1)) + end + + Base.@propagate_inbounds function Broadcast._broadcast_getindex( + X::$T, + I, + ) where {$Twhere} + return X.$field + end + + @inline function Base.copyto!( + dest::$T, + bc::Broadcast.Broadcasted{Broadcast.Style{$T}}, + ) where {$Twhere} + axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) + # Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match + if bc.f === identity && bc.args isa Tuple{$T} # only a single input argument to broadcast! + A = bc.args[1] + if axes(dest) == axes(A) + return copyto!(dest, A) + end + end + bc′ = Broadcast.preprocess(dest, bc) + # Performance may vary depending on whether `@inbounds` is placed outside the + # for loop or not. (cf. https://github.com/JuliaLang/julia/issues/38086) + copyto!(dest.$field, bc′[1]) + return dest + end + end, + ) +end diff --git a/src/vector_spaces.jl b/src/vector_spaces.jl index 50861883..c00a498e 100644 --- a/src/vector_spaces.jl +++ b/src/vector_spaces.jl @@ -75,6 +75,9 @@ While a [`AbstractManifold`](@ref) does not necessarily require this type, for e implemented for `Vector`s or `Matrix` type elements, this type can be used for more complicated representations, semantic verification, or even dispatch for different representations of tangent vectors and their types on a manifold. + +You may use macro [`@manifold_vector_forwards`](@ref) to introduce commonly used method +definitions for your subtype of `AbstractFibreVector`. """ abstract type AbstractFibreVector{TType<:VectorSpaceType} end diff --git a/test/default_manifold.jl b/test/default_manifold.jl index 6e7f01b8..f0477d4b 100644 --- a/test/default_manifold.jl +++ b/test/default_manifold.jl @@ -1,5 +1,18 @@ using ManifoldsBase - +using ManifoldsBase: + @manifold_element_forwards, @manifold_vector_forwards, @default_manifold_fallbacks +import ManifoldsBase: + number_eltype, + check_point, + distance, + embed!, + exp!, + inner, + isapprox, + log!, + retract!, + inverse_retract! +import Base: angle, convert using LinearAlgebra using DoubleFloats using ForwardDiff @@ -9,6 +22,7 @@ using Test struct CustomDefinedRetraction <: ManifoldsBase.AbstractRetractionMethod end struct CustomUndefinedRetraction <: ManifoldsBase.AbstractRetractionMethod end +struct CustomDefinedInverseRetraction <: ManifoldsBase.AbstractInverseRetractionMethod end function ManifoldsBase.injectivity_radius( ::ManifoldsBase.DefaultManifold, @@ -16,6 +30,25 @@ function ManifoldsBase.injectivity_radius( ) return 10.0 end +function ManifoldsBase.retract!( + ::ManifoldsBase.DefaultManifold, + q, + p, + X, + ::CustomDefinedRetraction, +) + return (q .= p .+ X) +end +function ManifoldsBase.inverse_retract!( + ::ManifoldsBase.DefaultManifold, + X, + p, + q, + ::CustomDefinedInverseRetraction, +) + return (X .= q .- p) +end + struct MatrixVectorTransport{T} <: AbstractVector{T} m::Matrix{T} @@ -25,6 +58,25 @@ Base.getindex(x::MatrixVectorTransport, i) = x.m[:, i] Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) +struct DefaultPoint{T} <: AbstractManifoldPoint + value::T +end +DefaultPoint(v::T) where {T} = DefaultPoint{T}(v) +convert(::Type{DefaultPoint{T}}, v::T) where {T} = DefaultPoint(v) + +Base.eltype(v::DefaultPoint) = eltype(v.value) + +struct DefaultTVector{T} <: TVector + value::T +end +DefaultTVector(v::T) where {T} = DefaultTVector{T}(v) + +Base.eltype(v::DefaultTVector) = eltype(v.value) + +ManifoldsBase.@manifold_element_forwards DefaultPoint value +ManifoldsBase.@manifold_vector_forwards DefaultTVector value +ManifoldsBase.@default_manifold_fallbacks ManifoldsBase.DefaultManifold DefaultPoint DefaultTVector value value + @testset "Testing Default (Euclidean)" begin M = ManifoldsBase.DefaultManifold(3) types = [ @@ -37,13 +89,14 @@ Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) Vector{Double64}, MVector{3,Double64}, SizedVector{3,Double64}, + DefaultPoint{Vector{Float64}}, ] @test repr(M) == "DefaultManifold(3; field = ℝ)" @test isa(manifold_dimension(M), Integer) @test manifold_dimension(M) ≥ 0 @test base_manifold(M) == M - @test number_system(M) == ℝ + @test number_system(M) == ManifoldsBase.ℝ @test ManifoldsBase.representation_size(M) == (3,) @test injectivity_radius(M) == Inf @@ -91,15 +144,17 @@ Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) retract!(M, new_pt, pts[1], tv1) @test is_point(M, new_pt) for x in pts - @test isapprox(M, zero_vector(M, x), log(M, x, x); atol = eps(eltype(x))) + @test isapprox(M, x, zero_vector(M, x), log(M, x, x); atol = eps(eltype(x))) @test isapprox( M, + x, zero_vector(M, x), inverse_retract(M, x, x); atol = eps(eltype(x)), ) @test isapprox( M, + x, zero_vector(M, x), inverse_retract(M, x, x, irm); atol = eps(eltype(x)), @@ -115,16 +170,16 @@ Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) @test distance(M, pts[1], pts[2]) ≈ norm(M, pts[1], tv1) - @test mid_point(M, pts[1], pts[2]) == [0.5, 0.5, 0.0] + @test mid_point(M, pts[1], pts[2]) == convert(T, [0.5, 0.5, 0.0]) midp = allocate(pts[1]) @test mid_point!(M, midp, pts[1], pts[2]) === midp - @test midp == [0.5, 0.5, 0.0] + @test midp == convert(T, [0.5, 0.5, 0.0]) @testset "Geodesic interface test" begin @test isapprox(M, geodesic(M, pts[1], tv1)(0.0), pts[1]) @test isapprox(M, geodesic(M, pts[1], tv1)(1.0), pts[2]) @test isapprox(M, geodesic(M, pts[1], tv1, 1.0), pts[2]) - @test isapprox(M, geodesic(M, pts[1], tv1, 1.0 / 2), (pts[1] + pts[2]) / 2) + @test isapprox(M, geodesic(M, pts[1], tv1, 1.0 / 2), midp) @test isapprox(M, shortest_geodesic(M, pts[1], pts[2])(0.0), pts[1]) @test isapprox(M, shortest_geodesic(M, pts[1], pts[2])(1.0), pts[2]) @test isapprox(M, shortest_geodesic(M, pts[1], pts[2], 0.0), pts[1]) @@ -133,14 +188,14 @@ Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) isapprox.( Ref(M), geodesic(M, pts[1], tv1, [0.0, 1.0 / 2, 1.0]), - [pts[1], (pts[1] + pts[2]) / 2, pts[2]], + [pts[1], midp, pts[2]], ), ) @test all( isapprox.( Ref(M), shortest_geodesic(M, pts[1], pts[2], [0.0, 1.0 / 2, 1.0]), - [pts[1], (pts[1] + pts[2]) / 2, pts[2]], + [pts[1], midp, pts[2]], ), ) end @@ -169,7 +224,7 @@ Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) @test a == b @test X == Y @test Z == X - @test a == vec(X) + @test a == ((T <: DefaultPoint) ? vec(X.value) : vec(X)) end @testset "broadcasted linear algebra in tangent space" begin @@ -178,7 +233,7 @@ Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) @test isapprox(M, pts[1], -tv1, .-tv1) v = similar(tv1) v .= 2 .* tv1 .+ tv1 - @test v ≈ 3 * tv1 + @test isapprox(M, pts[1], v, 3 * tv1) end @testset "project test" begin @@ -227,8 +282,14 @@ Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) vector_transport_along!(M, v1t5, pts[1], v1, c) @test isapprox(M, pts[1], v1, v1t5) # along a custom type of points - T = eltype(pts[1]) - c2 = MatrixVectorTransport{T}(reshape(pts[1], length(pts[1]), 1)) + if T <: DefaultPoint + S = eltype(pts[1].value) + mat = reshape(pts[1].value, length(pts[1].value), 1) + else + S = eltype(pts[1]) + mat = reshape(pts[1], length(pts[1]), 1) + end + c2 = MatrixVectorTransport{S}(mat) v1t4c2 = vector_transport_along(M, pts[1], v1, c2) @test isapprox(M, pts[1], v1, v1t4c2) v1t5c2 = allocate(v1) @@ -253,7 +314,12 @@ Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) ) == v2 # along is also the identity - c = [0.5 * (pts[1] + pts[2]), pts[2], 0.5 * (pts[2] + pts[3]), pts[3]] + c = [ + mid_point(M, pts[1], pts[2]), + pts[2], + mid_point(M, pts[2], pts[3]), + pts[3], + ] @test vector_transport_along(M, pts[1], v2, c, SchildsLadderTransport()) == v2 @test vector_transport_along(M, pts[1], v2, c, PoleLadderTransport()) == v2 @@ -262,9 +328,9 @@ Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) p = allocate(pts[1]) ManifoldsBase.pole_ladder!(M, p, pts[1], pts[2], pts[3]) # -log_p3 p == log_p1 p2 - @test isapprox(M, -log(M, pts[3], p), log(M, pts[1], pts[2])) + @test isapprox(M, pts[3], -log(M, pts[3], p), log(M, pts[1], pts[2])) ManifoldsBase.schilds_ladder!(M, p, pts[1], pts[2], pts[3]) - @test isapprox(M, log(M, pts[3], p), log(M, pts[1], pts[2])) + @test isapprox(M, pts[3], log(M, pts[3], p), log(M, pts[1], pts[2])) @test repr(ParallelTransport()) == "ParallelTransport()" @test repr(ScaledVectorTransport(ParallelTransport())) == @@ -306,24 +372,48 @@ Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) @test isapprox(M, fill(0.5), mid_point(M, p1, p2)) end - @testset "Retracion" begin + @testset "Retraction" begin a = NLsolveInverseRetraction(ExponentialRetraction()) @test a.retraction isa ExponentialRetraction end @testset "copy of points and vectors" begin - M = DefaultManifold(2) - p = [2.0, 3.0] - q = similar(p) - copyto!(M, q, p) - @test p == q - r = copy(M, p) - @test r == p - X = [4.0, 5.0] - Y = similar(X) - copyto!(M, Y, p, X) - @test Y == X - Z = copy(M, p, X) - @test Z == X + M = ManifoldsBase.DefaultManifold(2) + for (p, X) in ( + ([2.0, 3.0], [4.0, 5.0]), + (DefaultPoint([2.0, 3.0]), DefaultTVector([4.0, 5.0])), + ) + q = similar(p) + copyto!(M, q, p) + @test p == q + r = copy(M, p) + @test r == p + Y = similar(X) + copyto!(M, Y, p, X) + @test Y == X + Z = copy(M, p, X) + @test Z == X + end + + p1 = DefaultPoint([2.0, 3.0]) + p2 = copy(p1) + @test (p1 == p2) && (p1 !== p2) + end + @testset "further vector and point automatic forwards" begin + M = ManifoldsBase.DefaultManifold(3) + p = DefaultPoint([1.0, 0.0, 0.0]) + q = DefaultPoint([0.0, 0.0, 0.0]) + X = DefaultTVector([0.0, 1.0, 0.0]) + Y = DefaultTVector([1.0, 0.0, 0.0]) + @test angle(M, p, X, Y) ≈ π / 2 + @test inverse_retract(M, p, q, LogarithmicInverseRetraction()) == -Y + @test retract(M, q, Y, ExponentialRetraction()) == p + # Dispatch on custom + @test_broken inverse_retract(M, p, q, CustomDefinedInverseRetraction()) == -Y + @test_broken retract(M, q, Y, CustomDefinedRetraction()) == p + @test 2.0 \ X == DefaultTVector(2.0 \ X.value) + @test X + Y == DefaultTVector(X.value + Y.value) + @test +X == X + @test (Y .= X) === Y end end diff --git a/test/validation_manifold.jl b/test/validation_manifold.jl index d74d7c02..c081a13e 100644 --- a/test/validation_manifold.jl +++ b/test/validation_manifold.jl @@ -73,6 +73,10 @@ end @test isapprox(A, (a - b), T(v - w)) @test isapprox(A, -b, T(-w)) @test isapprox(A, 2 * a, T(2 .* v)) + @test isapprox(A, 2 .* a .+ b, T(2 .* v .+ w)) + c = similar(a) + c .= a .+ b + @test isapprox(A, c, a .+ b) end end @testset "AbstractManifold functions" begin