Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Mar 24, 2024
1 parent feb9d0c commit bcbc649
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 136 deletions.
253 changes: 126 additions & 127 deletions lib/OptimizationManopt/src/OptimizationManopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,113 +20,112 @@ end
## gradient descent

# Wrapper type selecting Manopt's gradient descent solver.
# `Teval` records the Manopt evaluation type (allocating vs. in-place) purely
# at the type level so `call_manopt_optimizer` can dispatch on it; it has no field.
struct GradientDescentOptimizer{
    Teval <: AbstractEvaluationType,
    TM <: AbstractManifold,
    TLS <: Linesearch
} <: AbstractManoptOptimizer
    M::TM          # manifold the optimization is performed on
    stepsize::TLS  # line-search strategy passed through to Manopt
end

"""
    GradientDescentOptimizer(M::AbstractManifold; eval, stepsize)

Convenience constructor: build a `GradientDescentOptimizer` on manifold `M`,
defaulting to an allocating evaluation and an Armijo line search on `M`.
"""
function GradientDescentOptimizer(M::AbstractManifold;
        eval::AbstractEvaluationType = Manopt.AllocatingEvaluation(),
        stepsize::Stepsize = ArmijoLinesearch(M))
    # `eval` only contributes its type; the evaluation object itself is not stored.
    return GradientDescentOptimizer{typeof(eval), typeof(M), typeof(stepsize)}(M, stepsize)
end

"""
    call_manopt_optimizer(opt::GradientDescentOptimizer, loss, gradF, x0, stopping_criterion)

Run `Manopt.gradient_descent` on `opt.M` starting from `x0`.

Returns a tuple `(result, retcode)` where `result` is a named tuple with the
`minimizer`, the `minimum` (loss re-evaluated at the minimizer), and the raw
Manopt solver state in `options`. The return code is the placeholder
`:who_knows` because Manopt does not expose a SciML-style retcode.
"""
function call_manopt_optimizer(opt::GradientDescentOptimizer{Teval},
        loss,
        gradF,
        x0,
        stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {
        Teval <: AbstractEvaluationType}
    # Translate the (possibly absent) stopping criterion into solver kwargs.
    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
    opts = gradient_descent(opt.M,
        loss,
        gradF,
        x0;
        return_state = true,
        evaluation = Teval(),
        stepsize = opt.stepsize,
        sckwarg...)
    # we unwrap DebugOptions here
    minimizer = Manopt.get_solver_result(opts)
    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
    :who_knows
end

## Nelder-Mead

# Wrapper type selecting Manopt's derivative-free Nelder-Mead solver.
# Only the manifold is needed; Nelder-Mead uses no gradient or stepsize.
struct NelderMeadOptimizer{
    TM <: AbstractManifold
} <: AbstractManoptOptimizer
    M::TM  # manifold the optimization is performed on
end


"""
    call_manopt_optimizer(opt::NelderMeadOptimizer, loss, gradF, x0, stopping_criterion)

Run `Manopt.NelderMead` on `opt.M`. `gradF` and `x0` are accepted for interface
uniformity with the other `call_manopt_optimizer` methods but are not forwarded:
Nelder-Mead is derivative-free and Manopt draws its own initial simplex.
Returns `(result, retcode)` in the same shape as the other methods.
"""
function call_manopt_optimizer(opt::NelderMeadOptimizer,
        loss,
        gradF,
        x0,
        stopping_criterion::Union{Nothing, Manopt.StoppingCriterion})
    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)

    opts = NelderMead(opt.M,
        loss;
        return_state = true,
        sckwarg...)
    minimizer = Manopt.get_solver_result(opts)
    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
    :who_knows
end

## conjugate gradient descent

# Wrapper type selecting Manopt's conjugate gradient descent solver.
# `Teval` is a type-level tag (allocating vs. in-place evaluation) used for
# dispatch in `call_manopt_optimizer`; it has no corresponding field.
struct ConjugateGradientDescentOptimizer{Teval <: AbstractEvaluationType,
    TM <: AbstractManifold, TLS <: Stepsize} <:
       AbstractManoptOptimizer
    M::TM          # manifold the optimization is performed on
    stepsize::TLS  # step-size / line-search strategy passed to Manopt
end

"""
    ConjugateGradientDescentOptimizer(M::AbstractManifold; eval, stepsize)

Convenience constructor: build a `ConjugateGradientDescentOptimizer` on `M`,
defaulting to in-place evaluation and an Armijo line search on `M`.
"""
function ConjugateGradientDescentOptimizer(M::AbstractManifold;
        eval::AbstractEvaluationType = InplaceEvaluation(),
        stepsize::Stepsize = ArmijoLinesearch(M))
    # `eval` only contributes its type; the evaluation object itself is not stored.
    return ConjugateGradientDescentOptimizer{typeof(eval), typeof(M), typeof(stepsize)}(M,
        stepsize)
end

"""
    call_manopt_optimizer(opt::ConjugateGradientDescentOptimizer, loss, gradF, x0, stopping_criterion)

Run `Manopt.conjugate_gradient_descent` on `opt.M` starting from `x0`, using the
evaluation type carried in `opt`'s type parameter and the stored stepsize.
Returns `(result, retcode)` in the same shape as the other methods; the retcode
is the placeholder `:who_knows`.
"""
function call_manopt_optimizer(opt::ConjugateGradientDescentOptimizer{Teval},
        loss,
        gradF,
        x0,
        stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {
        Teval <: AbstractEvaluationType}
    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
    opts = conjugate_gradient_descent(opt.M,
        loss,
        gradF,
        x0;
        return_state = true,
        evaluation = Teval(),
        stepsize = opt.stepsize,
        sckwarg...)
    # we unwrap DebugOptions here
    minimizer = Manopt.get_solver_result(opts)
    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
    :who_knows
end

## particle swarm

struct ParticleSwarmOptimizer{Teval <: AbstractEvaluationType,
TM <: AbstractManifold, Tretr <: AbstractRetractionMethod,
Tinvretr <: AbstractInverseRetractionMethod,
Tvt <: AbstractVectorTransportMethod} <:
TM <: AbstractManifold, Tretr <: AbstractRetractionMethod,
Tinvretr <: AbstractInverseRetractionMethod,
Tvt <: AbstractVectorTransportMethod} <:
AbstractManoptOptimizer
M::TM
retraction_method::Tretr
Expand All @@ -136,50 +135,50 @@ struct ParticleSwarmOptimizer{Teval <: AbstractEvaluationType,
end

"""
    ParticleSwarmOptimizer(M::AbstractManifold; eval, population_size,
        retraction_method, inverse_retraction_method, vector_transport_method)

Convenience constructor: build a `ParticleSwarmOptimizer` on `M` with a default
population of 100 particles and the manifold's default retraction, inverse
retraction, and vector transport methods.
"""
function ParticleSwarmOptimizer(M::AbstractManifold;
        eval::AbstractEvaluationType = InplaceEvaluation(),
        population_size::Int = 100,
        retraction_method::AbstractRetractionMethod = default_retraction_method(M),
        inverse_retraction_method::AbstractInverseRetractionMethod = default_inverse_retraction_method(M),
        vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M))
    # `eval` only contributes its type; the evaluation object itself is not stored.
    return ParticleSwarmOptimizer{typeof(eval), typeof(M), typeof(retraction_method),
        typeof(inverse_retraction_method),
        typeof(vector_transport_method)}(M,
        retraction_method,
        inverse_retraction_method,
        vector_transport_method,
        population_size)
end

"""
    call_manopt_optimizer(opt::ParticleSwarmOptimizer, loss, gradF, x0, stopping_criterion)

Run `Manopt.particle_swarm` on `opt.M`. The initial population consists of `x0`
plus `population_size - 1` random points drawn from the manifold, so the user's
starting point is always a member of the swarm. `gradF` is accepted for
interface uniformity but unused (particle swarm is derivative-free).
Returns `(result, retcode)` in the same shape as the other methods.
"""
function call_manopt_optimizer(opt::ParticleSwarmOptimizer{Teval},
        loss,
        gradF,
        x0,
        stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {
        Teval <: AbstractEvaluationType}
    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
    # Seed the swarm with x0 and fill the remainder with random manifold points.
    initial_population = vcat([x0], [rand(opt.M) for _ in 1:(opt.population_size - 1)])
    opts = particle_swarm(opt.M,
        loss;
        x0 = initial_population,
        n = opt.population_size,
        return_state = true,
        retraction_method = opt.retraction_method,
        inverse_retraction_method = opt.inverse_retraction_method,
        vector_transport_method = opt.vector_transport_method,
        sckwarg...)
    # we unwrap DebugOptions here
    minimizer = Manopt.get_solver_result(opts)
    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
    :who_knows
end

## quasi Newton

struct QuasiNewtonOptimizer{Teval <: AbstractEvaluationType,
TM <: AbstractManifold, Tretr <: AbstractRetractionMethod,
Tvt <: AbstractVectorTransportMethod, TLS <: Stepsize} <:
TM <: AbstractManifold, Tretr <: AbstractRetractionMethod,
Tvt <: AbstractVectorTransportMethod, TLS <: Stepsize} <:
AbstractManoptOptimizer
M::TM
retraction_method::Tretr
Expand All @@ -188,43 +187,43 @@ struct QuasiNewtonOptimizer{Teval <: AbstractEvaluationType,
end

"""
    QuasiNewtonOptimizer(M::AbstractManifold; eval, retraction_method,
        vector_transport_method, stepsize)

Convenience constructor: build a `QuasiNewtonOptimizer` on `M`, defaulting to
in-place evaluation, the manifold's default retraction and vector transport,
and a Wolfe-Powell line search configured with those same methods.
"""
function QuasiNewtonOptimizer(M::AbstractManifold;
        eval::AbstractEvaluationType = InplaceEvaluation(),
        retraction_method::AbstractRetractionMethod = default_retraction_method(M),
        vector_transport_method::AbstractVectorTransportMethod = default_vector_transport_method(M),
        stepsize = WolfePowellLinesearch(M;
            retraction_method = retraction_method,
            vector_transport_method = vector_transport_method,
            linesearch_stopsize = 1e-12))
    # `eval` only contributes its type; the evaluation object itself is not stored.
    return QuasiNewtonOptimizer{typeof(eval), typeof(M), typeof(retraction_method),
        typeof(vector_transport_method), typeof(stepsize)}(M,
        retraction_method,
        vector_transport_method,
        stepsize)
end

"""
    call_manopt_optimizer(opt::QuasiNewtonOptimizer, loss, gradF, x0, stopping_criterion)

Run `Manopt.quasi_Newton` on `opt.M` starting from `x0`, forwarding the stored
retraction, vector transport, and stepsize. Returns `(result, retcode)` in the
same shape as the other methods; the retcode is the placeholder `:who_knows`.
"""
function call_manopt_optimizer(opt::QuasiNewtonOptimizer{Teval},
        loss,
        gradF,
        x0,
        stopping_criterion::Union{Nothing, Manopt.StoppingCriterion}) where {
        Teval <: AbstractEvaluationType}
    sckwarg = stopping_criterion_to_kwarg(stopping_criterion)
    opts = quasi_Newton(opt.M,
        loss,
        gradF,
        x0;
        return_state = true,
        evaluation = Teval(),
        retraction_method = opt.retraction_method,
        vector_transport_method = opt.vector_transport_method,
        stepsize = opt.stepsize,
        sckwarg...)
    # we unwrap DebugOptions here
    minimizer = Manopt.get_solver_result(opts)
    return (; minimizer = minimizer, minimum = loss(opt.M, minimizer), options = opts),
    :who_knows
end

## Optimization.jl stuff
Expand Down Expand Up @@ -255,15 +254,15 @@ end
# 3) add callbacks to Manopt.jl

function SciMLBase.__solve(prob::OptimizationProblem,
opt::AbstractManoptOptimizer,
data = Optimization.DEFAULT_DATA;
callback = (args...) -> (false),
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
abstol::Union{Number, Nothing} = nothing,
reltol::Union{Number, Nothing} = nothing,
progress = false,
kwargs...)
opt::AbstractManoptOptimizer,
data = Optimization.DEFAULT_DATA;
callback = (args...) -> (false),
maxiters::Union{Number, Nothing} = nothing,
maxtime::Union{Number, Nothing} = nothing,
abstol::Union{Number, Nothing} = nothing,
reltol::Union{Number, Nothing} = nothing,
progress = false,
kwargs...)
local x, cur, state

manifold = haskey(prob.kwargs, :manifold) ? prob.kwargs[:manifold] : nothing
Expand Down Expand Up @@ -295,12 +294,12 @@ function SciMLBase.__solve(prob::OptimizationProblem,
opt_res, opt_ret = call_manopt_optimizer(opt, _loss, gradF, prob.u0, stopping_criterion)

return SciMLBase.build_solution(SciMLBase.DefaultOptimizationCache(prob.f, prob.p),
opt,
opt_res.minimizer,
prob.sense === Optimization.MaxSense ?
-opt_res.minimum : opt_res.minimum;
original = opt_res.options,
retcode = opt_ret)
opt,
opt_res.minimizer,
prob.sense === Optimization.MaxSense ?
-opt_res.minimum : opt_res.minimum;
original = opt_res.options,
retcode = opt_ret)
end

end # module OptimizationManopt
end # module OptimizationManopt
Loading

0 comments on commit bcbc649

Please sign in to comment.