Skip to content

Commit

Permalink
Merge pull request #15 from JuliaNLSolvers/mbaran/check-error-type
Browse files Browse the repository at this point in the history
check_manifold_point and check_tangent_vector throw right error types
  • Loading branch information
mateuszbaran authored Nov 26, 2019
2 parents 9250e24 + c1c04f8 commit 9c2f41c
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 112 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ManifoldsBase"
uuid = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
authors = ["Seth Axen <[email protected]>", "Mateusz Baran <[email protected]>", "Ronny Bergmann <[email protected]>", "Antoine Levitt <[email protected]>"]
version = "0.1.0"
version = "0.2.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
99 changes: 53 additions & 46 deletions src/ArrayManifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,106 +101,113 @@ array_value(v::ArrayCoTVector) = v.value


function isapprox(M::ArrayManifold, x, y; kwargs...)
is_manifold_point(M, x; kwargs...)
is_manifold_point(M, y; kwargs...)
is_manifold_point(M, x, true; kwargs...)
is_manifold_point(M, y, true; kwargs...)
return isapprox(M.manifold, array_value(x), array_value(y); kwargs...)
end

function isapprox(M::ArrayManifold, x, v, w; kwargs...)
is_manifold_point(M, x; kwargs...)
is_tangent_vector(M, x, v; kwargs...)
is_tangent_vector(M, x, w; kwargs...)
is_manifold_point(M, x, true; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
is_tangent_vector(M, x, w, true; kwargs...)
return isapprox(M.manifold, array_value(x), array_value(v), array_value(w); kwargs...)
end

function project_tangent!(M::ArrayManifold, w, x, v; kwargs...)
is_manifold_point(M, x; kwargs...)
is_manifold_point(M, x, true; kwargs...)
project_tangent!(M.manifold, w.value, array_value(x), array_value(v))
is_tangent_vector(M, x, w; kwargs...)
is_tangent_vector(M, x, w, true; kwargs...)
return w
end

function distance(M::ArrayManifold, x, y; kwargs...)
is_manifold_point(M, x; kwargs...)
is_manifold_point(M, y; kwargs...)
is_manifold_point(M, x, true; kwargs...)
is_manifold_point(M, y, true; kwargs...)
return distance(M.manifold, array_value(x), array_value(y))
end

function inner(M::ArrayManifold, x, v, w; kwargs...)
is_manifold_point(M, x; kwargs...)
is_tangent_vector(M, x, v; kwargs...)
is_tangent_vector(M, x, w; kwargs...)
is_manifold_point(M, x, true; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
is_tangent_vector(M, x, w, true; kwargs...)
return inner(M.manifold, array_value(x), array_value(v), array_value(w))
end

function exp(M::ArrayManifold, x, v; kwargs...)
is_manifold_point(M, x; kwargs...)
is_tangent_vector(M, x, v; kwargs...)
is_manifold_point(M, x, true; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
y = ArrayMPoint(exp(M.manifold, array_value(x), array_value(v)))
is_manifold_point(M, y; kwargs...)
is_manifold_point(M, y, true; kwargs...)
return y
end

function exp!(M::ArrayManifold, y, x, v; kwargs...)
is_manifold_point(M, x; kwargs...)
is_tangent_vector(M, x, v; kwargs...)
is_manifold_point(M, x, true; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
exp!(M.manifold, array_value(y), array_value(x), array_value(v))
is_manifold_point(M, y; kwargs...)
is_manifold_point(M, y, true; kwargs...)
return y
end

function log(M::ArrayManifold, x, y; kwargs...)
is_manifold_point(M, x; kwargs...)
is_manifold_point(M, y; kwargs...)
is_manifold_point(M, x, true; kwargs...)
is_manifold_point(M, y, true; kwargs...)
v = ArrayTVector(log(M.manifold, array_value(x), array_value(y)))
is_tangent_vector(M, x, v; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
return v
end

function log!(M::ArrayManifold, v, x, y; kwargs...)
is_manifold_point(M, x; kwargs...)
is_manifold_point(M, y; kwargs...)
is_manifold_point(M, x, true; kwargs...)
is_manifold_point(M, y, true; kwargs...)
log!(M.manifold, array_value(v), array_value(x), array_value(y))
is_tangent_vector(M, x, v; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
return v
end

function zero_tangent_vector!(M::ArrayManifold, v, x; kwargs...)
is_manifold_point(M, x; kwargs...)
is_manifold_point(M, x, true; kwargs...)
zero_tangent_vector!(M.manifold, array_value(v), array_value(x); kwargs...)
is_tangent_vector(M, x, v; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
return v
end

function zero_tangent_vector(M::ArrayManifold, x; kwargs...)
is_manifold_point(M, x; kwargs...)
is_manifold_point(M, x, true; kwargs...)
w = zero_tangent_vector(M.manifold, array_value(x))
is_tangent_vector(M, x, w; kwargs...)
is_tangent_vector(M, x, w, true; kwargs...)
return w
end

function vector_transport_to!(M::ArrayManifold, vto, x, v, y, m::AbstractVectorTransportMethod)
return vector_transport_to!(M.manifold,
array_value(vto),
array_value(x),
array_value(v),
array_value(y),
m)
function vector_transport_to!(M::ArrayManifold, vto, x, v, y, m::AbstractVectorTransportMethod; kwargs...)
is_manifold_point(M, y, true; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
vector_transport_to!(M.manifold,
array_value(vto),
array_value(x),
array_value(v),
array_value(y),
m)
is_tangent_vector(M, y, vto, true; kwargs...)
return vto
end

function vector_transport_along!(M::ArrayManifold, vto, x, v, c, m::AbstractVectorTransportMethod)
return vector_transport_along!(M.manifold,
array_value(vto),
array_value(x),
array_value(v),
c,
m)
function vector_transport_along!(M::ArrayManifold, vto, x, v, c, m::AbstractVectorTransportMethod; kwargs...)
is_tangent_vector(M, x, v, true; kwargs...)
vector_transport_along!(M.manifold,
array_value(vto),
array_value(x),
array_value(v),
c,
m)
is_tangent_vector(M, c(1), vto, true; kwargs...)
return vto
end

function is_manifold_point(M::ArrayManifold, x::MPoint; kwargs...)
return is_manifold_point(M.manifold, array_value(x); kwargs...)
function check_manifold_point(M::ArrayManifold, x::MPoint; kwargs...)
return check_manifold_point(M.manifold, array_value(x); kwargs...)
end

function is_tangent_vector(M::ArrayManifold, x::MPoint, v::TVector; kwargs...)
return is_tangent_vector(M.manifold, array_value(x), array_value(v); kwargs...)
function check_tangent_vector(M::ArrayManifold, x::MPoint, v::TVector; kwargs...)
return check_tangent_vector(M.manifold, array_value(x), array_value(v); kwargs...)
end
106 changes: 46 additions & 60 deletions src/ManifoldsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -566,98 +566,84 @@ function similar_result(M::Manifold, f, x...)
end

"""
manifold_point_error(M::Manifold, x; kwargs...)
check_manifold_point(M::Manifold, x; kwargs...)
Return `nothing` when `x` is a point on manifold `M`.
Otherwise, return a string with description why the point does not belong
to manifold `M`.
By default, `manifold_point_error` returns nothing for points not deriving
from the [`MPoint`](@ref) type.
By default, `check_manifold_point` returns `nothing`, i.e. if no checks are implmented,
the assumption is to be optimistic for point not deriving from the [`MPoint`](@ref) type.
"""
function manifold_point_error(M::Manifold, x; kwargs...)
function check_manifold_point(M::Manifold, x; kwargs...)
return nothing
end

function manifold_point_error(M::Manifold, x::MPoint; kwargs...)
error("manifold_point_error not implemented for manifold $(typeof(M)) and point $(typeof(x)).")
function check_manifold_point(M::Manifold, x::MPoint; kwargs...)
error("check_manifold_point not implemented for manifold $(typeof(M)) and point $(typeof(x)).")
end

"""
is_manifold_point(M,x)
is_manifold_point(M, x, throw_error = false; kwargs...)
check, whether `x` is a valid point on the [`Manifold`](@ref) `M`.
Returns either `true` or `false`.
The default is to return `true`, i.e. if no checks are implmented,
the assumption is to be optimistic.
"""
function is_manifold_point(M::Manifold, x; kwargs...)
return manifold_point_error(M, x; kwargs...) === nothing
end

"""
check_manifold_point(M,x)
check, whether `x` is a valid point on the [`Manifold`](@ref) `M`. If it is not,
an error is thrown.
The default is to return `true`, i.e. if no checks are implmented,
the assumption is to be optimistic.
"""
function check_manifold_point(M::Manifold, x; kwargs...)
mpe = manifold_point_error(M, x; kwargs...)
if mpe !== nothing
error(mpe)
If `throw_error` is false, the function returns either `true` or `false`.
If `throw_error` if true, the function either returns `true` or throws an error.
By default the function calls [`check_manifold_point`](@ref)`(M, x; kwargs...)`
and checks whether the returned value is `nothing` or an error.
"""
function is_manifold_point(M::Manifold, x, throw_error = false; kwargs...)
mpe = check_manifold_point(M, x; kwargs...)
if throw_error
if mpe !== nothing
throw(mpe)
end
return true
else
return mpe === nothing
end
end


"""
tangent_vector_error(M::Manifold, x, v; kwargs...)
check_tangent_vector(M::Manifold, x, v; kwargs...)
check, whether `v` is a valid tangent vector in the tangent plane of `x` on the
[`Manifold`](@ref) `M`. An implementation should first check
[`manifold_point_error`](@ref)`(M,x)` and then validate `v`. If it is not a tangent
vector error string should be returned.
[`manifold_point_error`](@ref)`(M, x; kwargs...)` and then validate `v`.
If it is not a tangent vector error string should be returned.
The default is to return `nothing`, i.e. if no checks are implmented,
the assumption is to be optimistic.
By default, `check_tangent_vector` returns `nothing`, i.e. if no checks are implmented,
the assumption is to be optimistic for tangent vectors not deriving from the [`TVector`](@ref) type.
"""
function tangent_vector_error(M::Manifold, x, v; kwargs...)
function check_tangent_vector(M::Manifold, x, v; kwargs...)
return nothing
end

function tangent_vector_error(M::Manifold, x::MPoint, v::TVector; kwargs...)
error("tangent_vector_error not implemented for manifold $(typeof(M)), point $(typeof(x)) and vector $(typeof(v)).")
function check_tangent_vector(M::Manifold, x::MPoint, v::TVector; kwargs...)
error("check_tangent_vector not implemented for manifold $(typeof(M)), point $(typeof(x)) and vector $(typeof(v)).")
end

"""
is_tangent_vector(M, x, v; kwargs...)
is_tangent_vector(M, x, v, throw_error = false; kwargs...)
check, whether `v` is a valid tangent vector at point `x` on
the [`Manifold`](@ref) `M`. Returns either `true` or `false`.
The default is to return `true`, i.e. if no checks are implmented,
the assumption is to be optimistic.
"""
function is_tangent_vector(M::Manifold, x, v; kwargs...)
return tangent_vector_error(M, x, v; kwargs...) === nothing
end

"""
check_tangent_vector(M, x, v; kwargs...)
check, whether `v` is a valid tangent vector in the tangent plane of `x` on the
[`Manifold`](@ref) `M`. An implementation should first check
[`manifold_point_error`](@ref)`(M,x)` and then validate `v`. If it is not a tangent
vector an error is thrown.
The default is to return `nothing`, i.e. if no checks are implmented,
the assumption is to be optimistic.
"""
function check_tangent_vector(M::Manifold, x, v; kwargs...)
tve = tangent_vector_error(M, x, v; kwargs...)
if tve !== nothing
error(tve)
the [`Manifold`](@ref) `M`.
If `throw_error` is false, the function returns either `true` or `false`.
If `throw_error` if true, the function either returns `true` or throws an error.
By default the function calls [`check_tangent_vector`](@ref)`(M, x, v; kwargs...)`
and checks whether the returned value is `nothing` or an error.
"""
function is_tangent_vector(M::Manifold, x, v, throw_error = false; kwargs...)
tve = check_tangent_vector(M, x, v; kwargs...)
if throw_error
if tve !== nothing
throw(tve)
end
return true
else
return tve === nothing
end
end

Expand Down
34 changes: 34 additions & 0 deletions test/domain_errors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using ManifoldsBase
using Test

struct ErrorTestManifold <: Manifold end

function ManifoldsBase.check_manifold_point(::ErrorTestManifold, x)
if any(u -> u < 0, x)
return DomainError(x, "<0")
end
return nothing
end
function ManifoldsBase.check_tangent_vector(M::ErrorTestManifold, x, v)
mpe = check_manifold_point(M, x)
mpe === nothing || return mpe
if any(u -> u < 0, v)
return DomainError(v, "<0")
end
return nothing
end

@testset "Domain errors" begin
M = ErrorTestManifold()
@test isa(check_manifold_point(M, [-1, 1]), DomainError)
@test check_manifold_point(M, [1, 1]) === nothing
@test !is_manifold_point(M, [-1, 1])
@test is_manifold_point(M, [1, 1])
@test_throws DomainError is_manifold_point(M, [-1, 1], true)

@test isa(check_tangent_vector(M, [1, 1], [-1, 1]), DomainError)
@test check_tangent_vector(M, [1, 1], [1, 1]) === nothing
@test !is_tangent_vector(M, [1, 1], [-1, 1])
@test is_tangent_vector(M, [1, 1], [1, 1])
@test_throws DomainError is_tangent_vector(M, [1, 1], [-1, 1], true)
end
10 changes: 5 additions & 5 deletions test/empty_manifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,14 @@ struct NonCoTVector <: CoTVector end

@test_throws ErrorException zero_tangent_vector!(m, [0], [0])
@test_throws ErrorException zero_tangent_vector(m, [0])
@test manifold_point_error(m, [0]) === nothing
@test_throws ErrorException manifold_point_error(m,p)

@test check_manifold_point(m, [0]) === nothing
@test_throws ErrorException check_manifold_point(m,p)
@test is_manifold_point(m, [0])
@test check_manifold_point(m, [0]) == nothing

@test tangent_vector_error(m, [0], [0]) === nothing
@test_throws ErrorException tangent_vector_error(m,p,v)
@test check_tangent_vector(m, [0], [0]) === nothing
@test_throws ErrorException check_tangent_vector(m,p,v)
@test is_tangent_vector(m, [0], [0])
@test check_tangent_vector(m, [0], [0]) == nothing

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
include("empty_manifold.jl")
include("default_manifold.jl")
include("array_manifold.jl")
include("domain_errors.jl")

2 comments on commit 9c2f41c

@mateuszbaran
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/5881

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.0 -m "<description of version>" 9c2f41c8a611e2fed2aa99a742684efa1a232c43
git push origin v0.2.0

Please sign in to comment.