Skip to content

Commit

Permalink
Introduce fallbacks for AbstractManifoldPoints and Vectors (#89)
Browse files Browse the repository at this point in the history
* Introduce default fallbacks for manifold point and tangent vector types, especially broadcasting to be passed on to the .value field
* forwarding macros

Co-authored-by: Mateusz Baran <[email protected]>
  • Loading branch information
kellertuer and mateuszbaran authored Dec 8, 2021
1 parent 7307caf commit 1de2668
Show file tree
Hide file tree
Showing 9 changed files with 476 additions and 67 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ManifoldsBase"
uuid = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
authors = ["Seth Axen <[email protected]>", "Mateusz Baran <[email protected]>", "Ronny Bergmann <[email protected]>", "Antoine Levitt <[email protected]>"]
version = "0.12.9"
version = "0.12.10"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
8 changes: 5 additions & 3 deletions src/ManifoldsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import Base:
show,
+,
-,
*
*,
==
import LinearAlgebra: dot, norm, det, cross, I, UniformScaling, Diagonal

import Markdown: @doc_str
Expand Down Expand Up @@ -62,6 +63,7 @@ allocate(a::AbstractArray{<:AbstractArray}, T::Type) = map(t -> allocate(t, T),
allocate(a::NTuple{N,AbstractArray} where {N}) = map(allocate, a)
allocate(a::NTuple{N,AbstractArray} where {N}, T::Type) = map(t -> allocate(t, T), a)


"""
allocate_result(M::AbstractManifold, f, x...)
Expand Down Expand Up @@ -100,7 +102,6 @@ Compute the angle between tangent vectors `X` and `Y` at point `p` from the
function angle(M::AbstractManifold, p, X, Y)
return acos(real(inner(M, p, X, Y)) / norm(M, p, X) / norm(M, p, Y))
end

"""
base_manifold(M::AbstractManifold, depth = Val(-1))
Expand Down Expand Up @@ -531,7 +532,7 @@ The size of an array representing a point on [`AbstractManifold`](@ref) `M`.
Returns `nothing` by default indicating that points are not represented using an
`AbstractArray`.
"""
function representation_size(M::AbstractManifold)
function representation_size(::AbstractManifold)
return nothing
end

Expand Down Expand Up @@ -569,6 +570,7 @@ include("vector_transport.jl")
include("DecoratorManifold.jl")
include("bases.jl")
include("vector_spaces.jl")
include("point_vector_fallbacks.jl")
include("ValidationManifold.jl")
include("EmbeddedManifold.jl")
include("DefaultManifold.jl")
Expand Down
32 changes: 2 additions & 30 deletions src/ValidationManifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,26 +57,9 @@ This distinguished the value from [`ValidationMPoint`](@ref)s vectors of other t
"""
const ValidationCoTVector = ValidationFibreVector{CotangentSpaceType}

@eval @manifold_vector_forwards ValidationFibreVector{TType} TType value

function (+)(X::ValidationFibreVector{TType}, Y::ValidationFibreVector{TType}) where {TType}
return ValidationFibreVector{TType}(X.value + Y.value)
end
function (-)(X::ValidationFibreVector{TType}, Y::ValidationFibreVector{TType}) where {TType}
return ValidationFibreVector{TType}(X.value - Y.value)
end
(-)(X::ValidationFibreVector{TType}) where {TType} = ValidationFibreVector{TType}(-X.value)
function (*)(a::Number, X::ValidationFibreVector{TType}) where {TType}
return ValidationFibreVector{TType}(a * X.value)
end

allocate(p::ValidationMPoint) = ValidationMPoint(allocate(p.value))
allocate(p::ValidationMPoint, ::Type{T}) where {T} = ValidationMPoint(allocate(p.value, T))
function allocate(X::ValidationFibreVector{TType}) where {TType}
return ValidationFibreVector{TType}(allocate(X.value))
end
function allocate(X::ValidationFibreVector{TType}, ::Type{T}) where {TType,T}
return ValidationFibreVector{TType}(allocate(X.value, T))
end
@eval @manifold_element_forwards ValidationMPoint value

"""
array_value(p)
Expand Down Expand Up @@ -422,9 +405,7 @@ function mid_point!(M::ValidationManifold, q, p1, p2; kwargs...)
end

number_eltype(::Type{ValidationMPoint{V}}) where {V} = number_eltype(V)
number_eltype(p::ValidationMPoint) = number_eltype(p.value)
number_eltype(::Type{ValidationFibreVector{TType,V}}) where {TType,V} = number_eltype(V)
number_eltype(X::ValidationFibreVector) = number_eltype(X.value)

function project!(M::ValidationManifold, Y, p, X; kwargs...)
is_point(M, p, true; kwargs...)
Expand All @@ -433,15 +414,6 @@ function project!(M::ValidationManifold, Y, p, X; kwargs...)
return Y
end

similar(p::ValidationMPoint) = ValidationMPoint(similar(p.value))
similar(p::ValidationMPoint, ::Type{T}) where {T} = ValidationMPoint(similar(p.value, T))
function similar(X::ValidationFibreVector{TType}) where {TType}
return ValidationFibreVector{TType}(similar(X.value))
end
function similar(X::ValidationFibreVector{TType}, ::Type{T}) where {TType,T}
return ValidationFibreVector{TType}(similar(X.value, T))
end

function vector_transport_along!(
M::ValidationManifold,
Y,
Expand Down
11 changes: 9 additions & 2 deletions src/bases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,13 @@ const DISAMBIGUATION_COTANGENT_BASIS_TYPES = [
DefaultOrthogonalBasis{<:Any,CotangentSpaceType},
]

"""
allocate_coordinates(M::AbstractManifold, p, T, n::Int)
Allocate vector of coordinates of length `n` of type `T` of a vector at point `p`
on manifold `M`.
"""
allocate_coordinates(M::AbstractManifold, p, T, n::Int) = allocate(p, T, n)

function allocate_result(
M::AbstractManifold,
Expand All @@ -294,7 +301,7 @@ function allocate_result(
B::AbstractBasis,
)
T = allocate_result_type(M, f, (p, X))
return allocate(p, T, number_of_coordinates(M, B))
return allocate_coordinates(M, p, T, number_of_coordinates(M, B))
end

function allocate_result(
Expand All @@ -305,7 +312,7 @@ function allocate_result(
B::CachedBasis,
)
T = allocate_result_type(M, f, (p, X))
return allocate(p, T, number_of_coordinates(M, B))
return allocate_coordinates(M, p, T, number_of_coordinates(M, B))
end

@inline function allocate_result_type(
Expand Down
10 changes: 8 additions & 2 deletions src/maintypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@ abstract type AbstractManifold{𝔽} end
Type for a point on a manifold. While a [`AbstractManifold`](@ref) does not necessarily require this
type, for example when it is implemented for `Vector`s or `Matrix` type elements, this type
can be used for more complicated representations, semantic verification, or even dispatch
for different representations of points on a manifold.
can be used either
* for more complicated representations,
* semantic verification, or
* even dispatch for different representations of points on a manifold.
Since semantic verification and different representations usually might still only store a
matrix internally, it is possible to use [`@manifold_element_forwards`](@ref) and
[`@default_manifold_fallbacks`](@ref) to reduce implementation overhead.
"""
abstract type AbstractManifoldPoint end
Loading

2 comments on commit 1de2668

@kellertuer
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/50151

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.12.10 -m "<description of version>" 1de2668e19047558a58f1b860586e588b9fc2d8c
git push origin v0.12.10

Please sign in to comment.