Skip to content

Commit

Permalink
restructuring of point and tangent vector checking
Browse files Browse the repository at this point in the history
  • Loading branch information
mateuszbaran committed Nov 26, 2019
1 parent e223a17 commit 59c40b5
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 119 deletions.
99 changes: 53 additions & 46 deletions src/ArrayManifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,106 +101,113 @@ array_value(v::ArrayCoTVector) = v.value


function isapprox(M::ArrayManifold, x, y; kwargs...)
is_manifold_point(M, x; kwargs...)
is_manifold_point(M, y; kwargs...)
is_manifold_point(M, x, true; kwargs...)
is_manifold_point(M, y, true; kwargs...)
return isapprox(M.manifold, array_value(x), array_value(y); kwargs...)
end

function isapprox(M::ArrayManifold, x, v, w; kwargs...)
is_manifold_point(M, x; kwargs...)
is_tangent_vector(M, x, v; kwargs...)
is_tangent_vector(M, x, w; kwargs...)
is_manifold_point(M, x, true; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
is_tangent_vector(M, x, w, true; kwargs...)
return isapprox(M.manifold, array_value(x), array_value(v), array_value(w); kwargs...)
end

function project_tangent!(M::ArrayManifold, w, x, v; kwargs...)
is_manifold_point(M, x; kwargs...)
is_manifold_point(M, x, true; kwargs...)
project_tangent!(M.manifold, w.value, array_value(x), array_value(v))
is_tangent_vector(M, x, w; kwargs...)
is_tangent_vector(M, x, w, true; kwargs...)
return w
end

function distance(M::ArrayManifold, x, y; kwargs...)
is_manifold_point(M, x; kwargs...)
is_manifold_point(M, y; kwargs...)
is_manifold_point(M, x, true; kwargs...)
is_manifold_point(M, y, true; kwargs...)
return distance(M.manifold, array_value(x), array_value(y))
end

function inner(M::ArrayManifold, x, v, w; kwargs...)
is_manifold_point(M, x; kwargs...)
is_tangent_vector(M, x, v; kwargs...)
is_tangent_vector(M, x, w; kwargs...)
is_manifold_point(M, x, true; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
is_tangent_vector(M, x, w, true; kwargs...)
return inner(M.manifold, array_value(x), array_value(v), array_value(w))
end

function exp(M::ArrayManifold, x, v; kwargs...)
is_manifold_point(M, x; kwargs...)
is_tangent_vector(M, x, v; kwargs...)
is_manifold_point(M, x, true; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
y = ArrayMPoint(exp(M.manifold, array_value(x), array_value(v)))
is_manifold_point(M, y; kwargs...)
is_manifold_point(M, y, true; kwargs...)
return y
end

function exp!(M::ArrayManifold, y, x, v; kwargs...)
is_manifold_point(M, x; kwargs...)
is_tangent_vector(M, x, v; kwargs...)
is_manifold_point(M, x, true; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
exp!(M.manifold, array_value(y), array_value(x), array_value(v))
is_manifold_point(M, y; kwargs...)
is_manifold_point(M, y, true; kwargs...)
return y
end

function log(M::ArrayManifold, x, y; kwargs...)
is_manifold_point(M, x; kwargs...)
is_manifold_point(M, y; kwargs...)
is_manifold_point(M, x, true; kwargs...)
is_manifold_point(M, y, true; kwargs...)
v = ArrayTVector(log(M.manifold, array_value(x), array_value(y)))
is_tangent_vector(M, x, v; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
return v
end

function log!(M::ArrayManifold, v, x, y; kwargs...)
is_manifold_point(M, x; kwargs...)
is_manifold_point(M, y; kwargs...)
is_manifold_point(M, x, true; kwargs...)
is_manifold_point(M, y, true; kwargs...)
log!(M.manifold, array_value(v), array_value(x), array_value(y))
is_tangent_vector(M, x, v; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
return v
end

function zero_tangent_vector!(M::ArrayManifold, v, x; kwargs...)
is_manifold_point(M, x; kwargs...)
is_manifold_point(M, x, true; kwargs...)
zero_tangent_vector!(M.manifold, array_value(v), array_value(x); kwargs...)
is_tangent_vector(M, x, v; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
return v
end

function zero_tangent_vector(M::ArrayManifold, x; kwargs...)
is_manifold_point(M, x; kwargs...)
is_manifold_point(M, x, true; kwargs...)
w = zero_tangent_vector(M.manifold, array_value(x))
is_tangent_vector(M, x, w; kwargs...)
is_tangent_vector(M, x, w, true; kwargs...)
return w
end

function vector_transport_to!(M::ArrayManifold, vto, x, v, y, m::AbstractVectorTransportMethod)
return vector_transport_to!(M.manifold,
array_value(vto),
array_value(x),
array_value(v),
array_value(y),
m)
function vector_transport_to!(M::ArrayManifold, vto, x, v, y, m::AbstractVectorTransportMethod; kwargs...)
is_manifold_point(M, y, true; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
vector_transport_to!(M.manifold,
array_value(vto),
array_value(x),
array_value(v),
array_value(y),
m)
is_tangent_vector(M, y, vto, true; kwargs...)
return vto
end

function vector_transport_along!(M::ArrayManifold, vto, x, v, c, m::AbstractVectorTransportMethod)
return vector_transport_along!(M.manifold,
array_value(vto),
array_value(x),
array_value(v),
c,
m)
function vector_transport_along!(M::ArrayManifold, vto, x, v, c, m::AbstractVectorTransportMethod; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
vector_transport_along!(M.manifold,
array_value(vto),
array_value(x),
array_value(v),
c,
m)
is_tangent_vector(M, c(1), vto, true; kwargs...)
return vto
end

function manifold_point_error(M::ArrayManifold, x::MPoint; kwargs...)
return manifold_point_error(M.manifold, array_value(x); kwargs...)
function check_manifold_point(M::ArrayManifold, x::MPoint; kwargs...)
return check_manifold_point(M.manifold, array_value(x); kwargs...)
end

function tangent_vector_error(M::ArrayManifold, x::MPoint, v::TVector; kwargs...)
return tangent_vector_error(M.manifold, array_value(x), array_value(v); kwargs...)
function check_tangent_vector(M::ArrayManifold, x::MPoint, v::TVector; kwargs...)
return check_tangent_vector(M.manifold, array_value(x), array_value(v); kwargs...)
end
106 changes: 46 additions & 60 deletions src/ManifoldsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -566,98 +566,84 @@ function similar_result(M::Manifold, f, x...)
end

"""
manifold_point_error(M::Manifold, x; kwargs...)
check_manifold_point(M::Manifold, x; kwargs...)
Return `nothing` when `x` is a point on manifold `M`.
Otherwise, return a string with description why the point does not belong
to manifold `M`.
By default, `manifold_point_error` returns nothing for points not deriving
from the [`MPoint`](@ref) type.
By default, `check_manifold_point` returns `nothing`, i.e. if no checks are implmented,
the assumption is to be optimistic for point not deriving from the [`MPoint`](@ref) type.
"""
function manifold_point_error(M::Manifold, x; kwargs...)
function check_manifold_point(M::Manifold, x; kwargs...)
return nothing
end

function manifold_point_error(M::Manifold, x::MPoint; kwargs...)
error("manifold_point_error not implemented for manifold $(typeof(M)) and point $(typeof(x)).")
function check_manifold_point(M::Manifold, x::MPoint; kwargs...)
error("check_manifold_point not implemented for manifold $(typeof(M)) and point $(typeof(x)).")
end

"""
is_manifold_point(M,x)
is_manifold_point(M, x, throw_error = false; kwargs...)
check, whether `x` is a valid point on the [`Manifold`](@ref) `M`.
Returns either `true` or `false`.
The default is to return `true`, i.e. if no checks are implmented,
the assumption is to be optimistic.
"""
function is_manifold_point(M::Manifold, x; kwargs...)
return manifold_point_error(M, x; kwargs...) === nothing
end

"""
check_manifold_point(M,x)
check, whether `x` is a valid point on the [`Manifold`](@ref) `M`. If it is not,
an error is thrown.
The default is to return `true`, i.e. if no checks are implmented,
the assumption is to be optimistic.
"""
function check_manifold_point(M::Manifold, x; kwargs...)
mpe = manifold_point_error(M, x; kwargs...)
if mpe !== nothing
throw(mpe)
If `throw_error` is false, the function returns either `true` or `false`.
If `throw_error` if true, the function either returns `true` or throws an error.
By default the function calls [`check_manifold_point`](@ref)`(M, x; kwargs...)`
and checks whether the returned value is `nothing` or an error.
"""
function is_manifold_point(M::Manifold, x, throw_error = false; kwargs...)
mpe = check_manifold_point(M, x; kwargs...)
if throw_error
if mpe !== nothing
throw(mpe)
end
return true
else
return mpe === nothing
end
end


"""
tangent_vector_error(M::Manifold, x, v; kwargs...)
check_tangent_vector(M::Manifold, x, v; kwargs...)
check, whether `v` is a valid tangent vector in the tangent plane of `x` on the
[`Manifold`](@ref) `M`. An implementation should first check
[`manifold_point_error`](@ref)`(M,x)` and then validate `v`. If it is not a tangent
vector error string should be returned.
[`manifold_point_error`](@ref)`(M, x; kwargs...)` and then validate `v`.
If it is not a tangent vector error string should be returned.
The default is to return `nothing`, i.e. if no checks are implmented,
the assumption is to be optimistic.
By default, `check_tangent_vector` returns `nothing`, i.e. if no checks are implmented,
the assumption is to be optimistic for tangent vectors not deriving from the [`TVector`](@ref) type.
"""
function tangent_vector_error(M::Manifold, x, v; kwargs...)
function check_tangent_vector(M::Manifold, x, v; kwargs...)
return nothing
end

function tangent_vector_error(M::Manifold, x::MPoint, v::TVector; kwargs...)
error("tangent_vector_error not implemented for manifold $(typeof(M)), point $(typeof(x)) and vector $(typeof(v)).")
function check_tangent_vector(M::Manifold, x::MPoint, v::TVector; kwargs...)
error("check_tangent_vector not implemented for manifold $(typeof(M)), point $(typeof(x)) and vector $(typeof(v)).")
end

"""
is_tangent_vector(M, x, v; kwargs...)
is_tangent_vector(M, x, v, throw_error = false; kwargs...)
check, whether `v` is a valid tangent vector at point `x` on
the [`Manifold`](@ref) `M`. Returns either `true` or `false`.
The default is to return `true`, i.e. if no checks are implmented,
the assumption is to be optimistic.
"""
function is_tangent_vector(M::Manifold, x, v; kwargs...)
return tangent_vector_error(M, x, v; kwargs...) === nothing
end

"""
check_tangent_vector(M, x, v; kwargs...)
check, whether `v` is a valid tangent vector in the tangent plane of `x` on the
[`Manifold`](@ref) `M`. An implementation should first check
[`manifold_point_error`](@ref)`(M,x)` and then validate `v`. If it is not a tangent
vector an error is thrown.
The default is to return `nothing`, i.e. if no checks are implmented,
the assumption is to be optimistic.
"""
function check_tangent_vector(M::Manifold, x, v; kwargs...)
tve = tangent_vector_error(M, x, v; kwargs...)
if tve !== nothing
throw(tve)
the [`Manifold`](@ref) `M`.
If `throw_error` is false, the function returns either `true` or `false`.
If `throw_error` if true, the function either returns `true` or throws an error.
By default the function calls [`check_tangent_vector`](@ref)`(M, x, v; kwargs...)`
and checks whether the returned value is `nothing` or an error.
"""
function is_tangent_vector(M::Manifold, x, v, throw_error = false; kwargs...)
tve = check_tangent_vector(M, x, v; kwargs...)
if throw_error
if tve !== nothing
throw(tve)
end
return true
else
return tve === nothing
end
end

Expand Down
18 changes: 9 additions & 9 deletions test/domain_errors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ using Test

struct ErrorTestManifold <: Manifold end

function ManifoldsBase.manifold_point_error(::ErrorTestManifold, x)
function ManifoldsBase.check_manifold_point(::ErrorTestManifold, x)
if any(u -> u < 0, x)
return DomainError(x, "<0")
end
return nothing
end
function ManifoldsBase.tangent_vector_error(M::ErrorTestManifold, x, v)
mpe = manifold_point_error(M, x)
function ManifoldsBase.check_tangent_vector(M::ErrorTestManifold, x, v)
mpe = check_manifold_point(M, x)
mpe === nothing || return mpe
if any(u -> u < 0, v)
return DomainError(v, "<0")
Expand All @@ -20,15 +20,15 @@ end

@testset "Domain errors" begin
M = ErrorTestManifold()
@test isa(manifold_point_error(M, [-1, 1]), DomainError)
@test manifold_point_error(M, [1, 1]) === nothing
@test isa(check_manifold_point(M, [-1, 1]), DomainError)
@test check_manifold_point(M, [1, 1]) === nothing
@test !is_manifold_point(M, [-1, 1])
@test is_manifold_point(M, [1, 1])
@test_throws DomainError check_manifold_point(M, [-1, 1])
@test_throws DomainError is_manifold_point(M, [-1, 1], true)

@test isa(tangent_vector_error(M, [1, 1], [-1, 1]), DomainError)
@test tangent_vector_error(M, [1, 1], [1, 1]) === nothing
@test isa(check_tangent_vector(M, [1, 1], [-1, 1]), DomainError)
@test check_tangent_vector(M, [1, 1], [1, 1]) === nothing
@test !is_tangent_vector(M, [1, 1], [-1, 1])
@test is_tangent_vector(M, [1, 1], [1, 1])
@test_throws DomainError check_tangent_vector(M, [1, 1], [-1, 1])
@test_throws DomainError is_tangent_vector(M, [1, 1], [-1, 1], true)
end
8 changes: 4 additions & 4 deletions test/empty_manifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,13 @@ struct NonCoTVector <: CoTVector end
@test_throws ErrorException zero_tangent_vector!(m, [0], [0])
@test_throws ErrorException zero_tangent_vector(m, [0])

@test manifold_point_error(m, [0]) === nothing
@test_throws ErrorException manifold_point_error(m,p)
@test check_manifold_point(m, [0]) === nothing
@test_throws ErrorException check_manifold_point(m,p)
@test is_manifold_point(m, [0])
@test check_manifold_point(m, [0]) == nothing

@test tangent_vector_error(m, [0], [0]) === nothing
@test_throws ErrorException tangent_vector_error(m,p,v)
@test check_tangent_vector(m, [0], [0]) === nothing
@test_throws ErrorException check_tangent_vector(m,p,v)
@test is_tangent_vector(m, [0], [0])
@test check_tangent_vector(m, [0], [0]) == nothing

Expand Down

0 comments on commit 59c40b5

Please sign in to comment.