Skip to content

Commit

Permalink
Improvements to cholesky rrules (#630)
Browse files Browse the repository at this point in the history
* Rewrite getproperty rule to store factors

* Work with factors directly

* Create tangent with factors

* Simplify and generalize cholesky number rule

* Use default tangent

* Generalize diagonal cholesky to Hermitian

* Simplify cholesky(::Diagonal) tests

* Generalize and simplify cholesky(::StridedMatrix)

* Fixes for Hermitian matrices

* Generalize to complex Hermitian matrices

* Remove unnecessary single-arg rule

* Reformat

* Check that check kwarg correctly passed

* Support failed factorizations

* Remove specializations for Thunks

* Release unnecessary constraints on factors

* Decrease step size

* Check complex cotangent for real primal works

* Fix diagonal rule for failed factorization

* Release type constraint of Diagonal

* Refer to real instead off complex

* Increment patch number

* Avoid unnecessary copies

* Update src/rulesets/LinearAlgebra/factorization.jl

Co-authored-by: David Widmann <[email protected]>

* Apply suggestions from code review

Co-authored-by: Frames Catherine White <[email protected]>

* Complexify with concrete types

Co-authored-by: David Widmann <[email protected]>
Co-authored-by: Frames Catherine White <[email protected]>
  • Loading branch information
3 people authored Jun 17, 2022
1 parent a0d86fe commit 6ff4c31
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 94 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.35.2"
version = "1.35.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
115 changes: 57 additions & 58 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -441,46 +441,30 @@ end
##### `cholesky`
#####

# these functions are defined outside the rrule because otherwise type inference breaks
# see https://github.com/JuliaLang/julia/issues/40990
_cholesky_real_pullback(ΔC::Tangent, full_pb) = return full_pb(ΔC)[1:2]
function _cholesky_real_pullback(Ȳ::AbstractThunk, full_pb)
return _cholesky_real_pullback(unthunk(Ȳ), full_pb)
end
function rrule(::typeof(cholesky),
A::Union{
Real,
Diagonal{<:Real},
LinearAlgebra.HermOrSym{<:LinearAlgebra.BlasReal,<:StridedMatrix},
StridedMatrix{<:LinearAlgebra.BlasReal}
}
# Handle not passing in the uplo
)
arg2 = A isa Real ? :U : Val(false)
C, full_pb = rrule(cholesky, A, arg2)

cholesky_pullback(ȳ) = return _cholesky_real_pullback(ȳ, full_pb)
return C, cholesky_pullback
end

function _cholesky_realuplo_pullback(ΔC::Tangent, C)
return NoTangent(), ΔC.factors[1, 1] / (2 * C.U[1, 1]), NoTangent()
end
_cholesky_realuplo_pullback(Ȳ::AbstractThunk, C) = _cholesky_realuplo_pullback(unthunk(Ȳ), C)
function rrule(::typeof(cholesky), A::Real, uplo::Symbol)
C = cholesky(A, uplo)
cholesky_pullback(ȳ) = _cholesky_realuplo_pullback(ȳ, C)
function rrule(::typeof(cholesky), x::Number, uplo::Symbol)
C = cholesky(x, uplo)
function cholesky_pullback(ΔC)
= real(only(unthunk(ΔC).factors)) / (2 * sign(real(x)) * only(C.factors))
return NoTangent(), Ā, NoTangent()
end
return C, cholesky_pullback
end

function _cholesky_Diagonal_pullback(ΔC::Tangent, C)
= Diagonal(diag(ΔC.factors) .* inv.(2 .* C.factors.diag))
return NoTangent(), Ā, NoTangent()
function _cholesky_Diagonal_pullback(ΔC, C)
Udiag = C.factors.diag
ΔUdiag = diag(ΔC.factors)
Ādiag = real.(ΔUdiag) ./ (2 .* Udiag)
if !issuccess(C)
# cholesky computes the factor diagonal from the beginning until it encounters the
# first failure. The remainder of the diagonal is then copied from the input.
i = findfirst(x -> !isreal(x) || !(real(x) > 0), Udiag)
Ādiag[i:end] .= ΔUdiag[i:end]
end
return NoTangent(), Diagonal(Ādiag), NoTangent()
end
_cholesky_Diagonal_pullback(Ȳ::AbstractThunk, C) = _cholesky_Diagonal_pullback(unthunk(Ȳ), C)
function rrule(::typeof(cholesky), A::Diagonal{<:Real}, ::Val{false}; check::Bool=true)
function rrule(::typeof(cholesky), A::Diagonal{<:Number}, ::Val{false}; check::Bool=true)
C = cholesky(A, Val(false); check=check)
cholesky_pullback(ȳ) = _cholesky_Diagonal_pullback(, C)
cholesky_pullback(ȳ) = _cholesky_Diagonal_pullback(unthunk(ȳ), C)
return C, cholesky_pullback
end

Expand All @@ -489,69 +473,84 @@ end
# Implementation due to Seeger, Matthias, et al. "Auto-differentiating linear algebra."
function rrule(
::typeof(cholesky),
A::LinearAlgebra.HermOrSym{<:LinearAlgebra.BlasReal, <:StridedMatrix},
A::LinearAlgebra.RealHermSymComplexHerm{<:Real, <:StridedMatrix},
::Val{false};
check::Bool=true,
)
C = cholesky(A, Val(false); check=check)
function _cholesky_HermOrSym_pullback(ΔC::Tangent)
, U = _cholesky_pullback_shared_code(C, ΔC)
= BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)) / 2, U.data, Ā)
function cholesky_HermOrSym_pullback(ΔC)
= _cholesky_pullback_shared_code(C, unthunk(ΔC))
rmul!(Ā, one(eltype(Ā)) / 2)
return NoTangent(), _symhermtype(A)(Ā), NoTangent()
end
_cholesky_HermOrSym_pullback(Ȳ::AbstractThunk) = _cholesky_HermOrSym_pullback(unthunk(Ȳ))
return C, _cholesky_HermOrSym_pullback
return C, cholesky_HermOrSym_pullback
end

function rrule(
::typeof(cholesky),
A::StridedMatrix{<:LinearAlgebra.BlasReal},
A::StridedMatrix{<:Union{Real,Complex}},
::Val{false};
check::Bool=true,
)
C = cholesky(A, Val(false); check=check)
function _cholesky_Strided_pullback(ΔC::Tangent)
Ā, U = _cholesky_pullback_shared_code(C, ΔC)
= BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)), U.data, Ā)
function cholesky_Strided_pullback(ΔC)
= _cholesky_pullback_shared_code(C, unthunk(ΔC))
idx = diagind(Ā)
@views Ā[idx] .= real.(Ā[idx]) ./ 2
return (NoTangent(), UpperTriangular(Ā), NoTangent())
end
_cholesky_Strided_pullback(Ȳ::AbstractThunk) = _cholesky_Strided_pullback(unthunk(Ȳ))
return C, _cholesky_Strided_pullback
return C, cholesky_Strided_pullback
end

function _cholesky_pullback_shared_code(C, ΔC)
U = C.U
= ΔC.U
= similar(U.data)
= mul!(Ā, Ū, U')
= LinearAlgebra.copytri!(Ā, 'U', true)
= ldiv!(U, Ā)
return Ā, U
Δfactors = ΔC.factors
= similar(C.factors)
if C.uplo === 'U'
U = C.U
= eltype(U) <: Real ? real(_maybeUpperTri(Δfactors)) : _maybeUpperTri(Δfactors)
mul!(Ā, Ū, U')
LinearAlgebra.copytri!(Ā, 'U', true)
eltype(Ā) <: Real || _realifydiag!(Ā)
ldiv!(U, Ā)
rdiv!(Ā, U')
else # C.uplo === 'L'
L = C.L
= eltype(L) <: Real ? real(_maybeLowerTri(Δfactors)) : _maybeLowerTri(Δfactors)
mul!(Ā, L', L̄)
LinearAlgebra.copytri!(Ā, 'L', true)
eltype(Ā) <: Real || _realifydiag!(Ā)
rdiv!(Ā, L)
ldiv!(L', Ā)
end
return
end

function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky}
function getproperty_cholesky_pullback(Ȳ)
C = Tangent{T}
∂F = if x === :U
if F.uplo === 'U'
C(U=UpperTriangular(Ȳ),)
C(factors=_maybeUpperTri(Ȳ),)
else
C(L=LowerTriangular'),)
C(factors=_maybeLowerTri'),)
end
elseif x === :L
if F.uplo === 'L'
C(L=LowerTriangular(Ȳ),)
C(factors=_maybeLowerTri(Ȳ),)
else
C(U=UpperTriangular'),)
C(factors=_maybeUpperTri'),)
end
end
return NoTangent(), ∂F, NoTangent()
end
return getproperty(F, x), getproperty_cholesky_pullback
end

_maybeUpperTri(A) = UpperTriangular(A)
_maybeUpperTri(A::Diagonal) = A
_maybeLowerTri(A) = LowerTriangular(A)
_maybeLowerTri(A::Diagonal) = A

# `det` and `logdet` for `Cholesky`
function rrule(::typeof(det), C::Cholesky)
y = det(C)
Expand Down
125 changes: 90 additions & 35 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,57 +382,112 @@ end
# have fantastic support for this stuff at the minute.
# also we might be missing some overloads for different tangent-types in the rules
@testset "cholesky" begin
@testset "Real" begin
test_rrule(cholesky, 0.8)
@testset "Number" begin
@testset "uplo=$uplo" for uplo in (:U, :L)
test_rrule(cholesky, 0.8, uplo)
test_rrule(cholesky, -0.3, uplo)
test_rrule(cholesky, 0.23 + 0im, uplo)
test_rrule(cholesky, 0.78 + 0.5im, uplo)
test_rrule(cholesky, -0.34 + 0.1im, uplo)
end
end
@testset "Diagonal{<:Real}" begin
D = Diagonal(rand(5) .+ 0.1)
C = cholesky(D)
test_rrule(
cholesky, D Diagonal(randn(5)), Val(false);
output_tangent=Tangent{typeof(C)}(factors=Diagonal(randn(5)))
)

@testset "Diagonal" begin
@testset "Diagonal{<:Real}" begin
test_rrule(cholesky, Diagonal([0.3, 0.2, 0.5, 0.6, 0.9]), Val(false))
end
@testset "Diagonal{<:Complex}" begin
# finite differences in general will produce matrices with non-real
# diagonals, which cause factorization to fail. If we turn off the check and
# ensure the cotangent is real, then test_rrule still works.
D = Diagonal([0.3 + 0im, 0.2, 0.5, 0.6, 0.9])
C = cholesky(D)
test_rrule(
cholesky, D, Val(false);
output_tangent=Tangent{typeof(C)}(factors=complex(randn(5, 5))),
fkwargs=(; check=false),
)
end
@testset "check has correct default and passed to primal" begin
@test_throws Exception rrule(cholesky, Diagonal(-rand(5)), Val(false))
rrule(cholesky, Diagonal(-rand(5)), Val(false); check=false)
end
@testset "failed factorization" begin
A = Diagonal(vcat(rand(4), -rand(4), rand(4)))
test_rrule(cholesky, A, Val(false); fkwargs=(; check=false))
end
end

X = generate_well_conditioned_matrix(10)
V = generate_well_conditioned_matrix(10)
F, dX_pullback = rrule(cholesky, X, Val(false))
F_1arg, dX_pullback_1arg = rrule(cholesky, X) # to test not passing the Val(false)
@test F == F_1arg
@testset "uplo=$p" for p in [:U, :L]
Y, dF_pullback = rrule(getproperty, F, p)
= (p === :U ? UpperTriangular : LowerTriangular)(randn(size(Y)))
(dself, dF, dp) = dF_pullback(Ȳ)
@test dself === NoTangent()
@test dp === NoTangent()

# NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp`
# machinery from FiniteDifferences because that isn't set up to respect
# necessary special properties of the input. In the case of the Cholesky
# factorization, we need the input to be Hermitian.
ΔF = unthunk(dF)
_, dX, darg2 = dX_pullback(ΔF)
_, dX_1arg = dX_pullback_1arg(ΔF)
@test dX == dX_1arg
@test darg2 === NoTangent()
X̄_ad = dot(unthunk(dX), V)
X̄_fd = central_fdm(5, 1)(0.000_001) do ε
dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p))
@testset "StridedMatrix" begin
@testset "Matrix{$T}" for T in (Float64, ComplexF64)
X = generate_well_conditioned_matrix(T, 10)
V = generate_well_conditioned_matrix(T, 10)
F, dX_pullback = rrule(cholesky, X, Val(false))
@testset "uplo=$p, cotangent eltype=$T" for p in [:U, :L], S in unique([T, complex(T)])
Y, dF_pullback = rrule(getproperty, F, p)
= randn(S, size(Y))
(dself, dF, dp) = dF_pullback(Ȳ)
@test dself === NoTangent()
@test dp === NoTangent()

# NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp`
# machinery from FiniteDifferences because that isn't set up to respect
# necessary special properties of the input. In the case of the Cholesky
# factorization, we need the input to be Hermitian.
ΔF = unthunk(dF)
_, dX, darg2 = dX_pullback(ΔF)
@test darg2 === NoTangent()
X̄_ad = real(dot(unthunk(dX), V))
X̄_fd = central_fdm(5, 1)(0.000_0001) do ε
real(dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p)))
end
@test X̄_ad X̄_fd rtol=1e-4
end
end
@testset "check has correct default and passed to primal" begin
# this will almost certainly be a non-PD matrix
X = Matrix(Symmetric(randn(10, 10)))
@test_throws Exception rrule(cholesky, X, Val(false))
rrule(cholesky, X, Val(false); check=false) # just check it doesn't throw
end
@test X̄_ad X̄_fd rtol=1e-4
end

# Ensure that cotangents of cholesky(::StridedMatrix) and
# (cholesky ∘ Symmetric)(::StridedMatrix) are equal.
@testset "Symmetric" begin
X = generate_well_conditioned_matrix(10)
F, dX_pullback = rrule(cholesky, X, Val(false))

X_symmetric, sym_back = rrule(Symmetric, X, :U)
C, chol_back_sym = rrule(cholesky, X_symmetric, Val(false))

Δ = Tangent{typeof(C)}((U=UpperTriangular(randn(size(X)))))
Δ = Tangent{typeof(C)}((factors=randn(size(X))))
ΔX_symmetric = chol_back_sym(Δ)[2]
@test sym_back(ΔX_symmetric)[2] dX_pullback(Δ)[2]
end

# Ensure that cotangents of cholesky(::StridedMatrix) and
# (cholesky ∘ Hermitian)(::StridedMatrix) are equal.
@testset "Hermitian" begin
@testset "Hermitian{$T}" for T in (Float64, ComplexF64)
X = generate_well_conditioned_matrix(T, 10)
F, dX_pullback = rrule(cholesky, X, Val(false))

X_hermitian, herm_back = rrule(Hermitian, X, :U)
C, chol_back_herm = rrule(cholesky, X_hermitian, Val(false))

Δ = Tangent{typeof(C)}((factors=randn(T, size(X))))
ΔX_hermitian = chol_back_herm(Δ)[2]
@test herm_back(ΔX_hermitian)[2] dX_pullback(Δ)[2]
end
@testset "check has correct default and passed to primal" begin
# this will almost certainly be a non-PD matrix
X = Hermitian(randn(10, 10))
@test_throws Exception rrule(cholesky, X, Val(false))
rrule(cholesky, X, Val(false); check=false)
end
end

@testset "det and logdet (uplo=$p)" for p in (:U, :L)
@testset "$op" for op in (det, logdet)
@testset "$T" for T in (Float64, ComplexF64)
Expand Down

2 comments on commit 6ff4c31

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/62507

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.35.3 -m "<description of version>" 6ff4c319f8fd25f27636d28144d78c92f81d8753
git push origin v1.35.3

Please sign in to comment.