Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return CGResult from cg and allow absolute tol #238

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/src/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ Most solvers contain the `log` keyword. This is to be used when obtaining
more information is required, to use it place the set `log` to `true`.

```julia
x, ch = cg(Master, rand(10, 10), rand(10) log=true)
r = cg(Master, rand(10, 10), rand(10) log=true)
x, ch = r.x, r.history
svd, L, ch = svdl(Master, rand(100, 100), log=true)
```

Expand Down
3 changes: 2 additions & 1 deletion docs/src/linear_systems/cg.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ n = 100
A = cu(rand(n, n))
A = A + A' + 2*n*I
b = cu(rand(n))
x = cg(A, b)
r = cg(A, b)
x = r.x
```

!!! note
Expand Down
69 changes: 53 additions & 16 deletions src/cg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mutable struct CGIterable{matT, solT, vecT, numT <: Real}
r::vecT
c::vecT
u::vecT
reltol::numT
tol::numT
residual::numT
prev_residual::numT
maxiter::Int
Expand All @@ -22,18 +22,49 @@ mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Numb
r::vecT
c::vecT
u::vecT
reltol::numT
tol::numT
residual::numT
ρ::paramT
maxiter::Int
mv_products::Int
end

@inline converged(it::Union{CGIterable, PCGIterable}) = it.residual ≤ it.reltol
struct CGResult{Tx, T, Thistory}
x::Tx
residual::T
tol::T
iterations::Int
maxiter::Int
converged::Bool
history::Thistory
end
function Base.show(io::IO, r::CGResult)
first_two(fr) = [x for (i, x) in enumerate(fr)][1:2]

@printf io "Result of CG Algorithm\n"
@printf io " * Algorithm: CG \n"

if length(join(r.x, ",")) < 40 || length(r.x) <= 2
@printf io " * x: [%s]\n" join(r.x, ",")
else
@printf io " * x: [%s, ...]\n" join(first_two(r.x), ",")
end

@printf io " * Convergence\n"
@printf io " * Residual: %s\n" r.residual
@printf io " * Tolerance: %s\n" r.tol
@printf io " * Converged: %s\n" r.converged
@printf io " * Iterations: %s\n" r.iterations
@printf io " * Iterations limit: %s\n" r.maxiter

return
end

@inline isconverged(it::Union{CGIterable, PCGIterable}) = it.residual ≤ it.tol

@inline start(it::Union{CGIterable, PCGIterable}) = 0

@inline done(it::Union{CGIterable, PCGIterable}, iteration::Int) = iteration ≥ it.maxiter || converged(it)
@inline done(it::Union{CGIterable, PCGIterable}, iteration::Int) = iteration ≥ it.maxiter || isconverged(it)


###############
Expand Down Expand Up @@ -114,7 +145,8 @@ struct CGStateVariables{T,Tx<:AbstractArray{T}}
end

function cg_iterator!(x, A, b, Pl = Identity();
tol = sqrt(eps(real(eltype(b)))),
reltol = sqrt(eps(real(eltype(b)))),
tol = zero(real(eltype(b))),
maxiter::Int = size(A, 2),
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
initially_zero::Bool = false
Expand All @@ -130,24 +162,24 @@ function cg_iterator!(x, A, b, Pl = Identity();
mv_products = 0
c = similar(x)
residual = norm(b)
reltol = residual * tol # Save one dot product
tol = max(residual * reltol, tol) # Save one dot product
else
mv_products = 1
mul!(c, A, x)
r .-= c
residual = norm(r)
reltol = norm(b) * tol
tol = max(norm(b) * reltol, tol)
end

# Return the iterable
if isa(Pl, Identity)
return CGIterable(A, x, r, c, u,
reltol, residual, one(residual),
tol, residual, one(residual),
maxiter, mv_products
)
else
return PCGIterable(Pl, A, x, r, c, u,
reltol, residual, one(eltype(x)),
tol, residual, one(eltype(x)),
maxiter, mv_products
)
end
Expand Down Expand Up @@ -177,7 +209,8 @@ cg(A, b; kwargs...) = cg!(zerox(A, b), A, b; initially_zero = true, kwargs...)
residual vector;
- `Pl = Identity()`: left preconditioner of the method. Should be symmetric,
positive-definite like `A`;
- `tol::Real = sqrt(eps(real(eltype(b))))`: tolerance for stopping condition `|r_k| / |r_0| ≤ tol`;
- `reltol::Real = sqrt(eps(real(eltype(b))))`: relative tolerance for stopping condition `|r_k| / |r_0| ≤ reltol`;
- `tol` = zero(real(eltype(b))): tolerance for stopping condition `|r_k| ≤ tol`,
- `maxiter::Int = size(A,2)`: maximum number of iterations;
- `verbose::Bool = false`: print method information;
- `log::Bool = false`: keep track of the residual norm in each iteration.
Expand All @@ -199,7 +232,8 @@ cg(A, b; kwargs...) = cg!(zerox(A, b), A, b; initially_zero = true, kwargs...)
- `:resnom` => `::Vector`: residual norm at each iteration.
"""
function cg!(x, A, b;
tol = sqrt(eps(real(eltype(b)))),
reltol = sqrt(eps(real(eltype(b)))),
tol = zero(real(eltype(b))),
maxiter::Int = size(A, 2),
log::Bool = false,
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
Expand All @@ -208,15 +242,17 @@ function cg!(x, A, b;
kwargs...
)
history = ConvergenceHistory(partial = !log)
history[:tol] = tol
log && reserve!(history, :resnorm, maxiter + 1)

# Actually perform CG
iterable = cg_iterator!(x, A, b, Pl; tol = tol, maxiter = maxiter, statevars = statevars, kwargs...)
iterable = cg_iterator!(x, A, b, Pl; tol = tol, reltol = reltol, maxiter = maxiter, statevars = statevars, kwargs...)
history[:tol] = iterable.tol
if log
history.mvps = iterable.mv_products
end
for (iteration, item) = enumerate(iterable)
iteration = 0
for item in iterable
iteration += 1
if log
nextiter!(history, mvps = 1)
push!(history, :resnorm, iterable.residual)
Expand All @@ -225,8 +261,9 @@ function cg!(x, A, b;
end

verbose && println()
log && setconv(history, converged(iterable))
converged = isconverged(iterable)
log && setconv(history, converged)
log && shrink!(history)

log ? (iterable.x, history) : iterable.x
return CGResult(iterable.x, iterable.residual, iterable.tol, iteration, maxiter, converged, history)
end
38 changes: 25 additions & 13 deletions test/cg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,31 @@ Random.seed!(1234321)
A = rand(T, n, n)
A = A' * A + I
b = rand(T, n)
tol = √eps(real(T))
reltol = √eps(real(T))

x,ch = cg(A, b; tol=tol, maxiter=2n, log=true)
r = cg(A, b; reltol=reltol, maxiter=2n, log=true)
x, ch = r.x, r.history
@test isa(ch, ConvergenceHistory)
@test norm(A*x - b) / norm(b) ≤ tol
@test norm(A*x - b) / norm(b) ≤ reltol
@test norm(A*x - b) ≤ r.tol
@test ch.isconverged

# If you start from the exact solution, you should converge immediately
x,ch = cg!(A \ b, A, b; tol=10tol, log=true)
r = cg!(A \ b, A, b; reltol=10reltol, log=true)
x, ch = r.x, r.history
@test niters(ch) ≤ 1
@test nprods(ch) ≤ 2

# Test with cholfact should converge immediately
F = cholesky(A, Val(false))
x,ch = cg(A, b; Pl=F, log=true)
r = cg(A, b; Pl=F, log=true)
x, ch = r.x, r.history
@test niters(ch) ≤ 2
@test nprods(ch) ≤ 2

# All-zeros rhs should give all-zeros lhs
x0 = cg(A, zeros(T, n))
r = cg(A, zeros(T, n))
x0 = r.x
@test x0 == zeros(T, n)
end
end
Expand All @@ -59,24 +64,30 @@ end
tol = 1e-5

@testset "SparseMatrixCSC{$T, $Ti}" for T in (Float64, Float32), Ti in (Int64, Int32)
xCG = cg(A, rhs; tol=tol, maxiter=100)
xJAC = cg(A, rhs; Pl=P, tol=tol, maxiter=100)
r = cg(A, rhs; tol=tol, maxiter=100)
xCG = r.x
r = cg(A, rhs; Pl=P, tol=tol, maxiter=100)
xJAC = r.x
@test norm(A * xCG - rhs) ≤ tol
@test norm(A * xJAC - rhs) ≤ tol
end

Af = LinearMap(A)
@testset "Function" begin
xCG = cg(Af, rhs; tol=tol, maxiter=100)
xJAC = cg(Af, rhs; Pl=P, tol=tol, maxiter=100)
r = cg(Af, rhs; tol=tol, maxiter=100)
xCG = r.x
r = cg(Af, rhs; Pl=P, tol=tol, maxiter=100)
xJAC = r.x
@test norm(A * xCG - rhs) ≤ tol
@test norm(A * xJAC - rhs) ≤ tol
end

@testset "Function with specified starting guess" begin
x0 = randn(size(rhs))
xCG, hCG = cg!(copy(x0), Af, rhs; tol=tol, maxiter=100, log=true)
xJAC, hJAC = cg!(copy(x0), Af, rhs; Pl=P, tol=tol, maxiter=100, log=true)
r = cg!(copy(x0), Af, rhs; tol=tol, maxiter=100, log=true)
xCG, hCG = r.x, r.history
r = cg!(copy(x0), Af, rhs; Pl=P, tol=tol, maxiter=100, log=true)
xJAC, hJAC = r.x, r.history
@test norm(A * xCG - rhs) ≤ tol
@test norm(A * xJAC - rhs) ≤ tol
@test niters(hJAC) == niters(hCG)
Expand All @@ -88,7 +99,8 @@ end
A = A + A' + 100I
x = view(rand(10, 2), :, 1)
b = rand(10)
x, hist = cg!(x, A, b, log = true)
r = cg!(x, A, b, log = true)
x, hist = r.x, r.history
@test hist.isconverged
end

Expand Down