diff --git a/Project.toml b/Project.toml index ec0410ffe..2433a21c1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.35.2" +version = "1.35.3" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index e47d4b2c2..f2f8d32d1 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -441,46 +441,30 @@ end ##### `cholesky` ##### -# these functions are defined outside the rrule because otherwise type inference breaks -# see https://github.com/JuliaLang/julia/issues/40990 -_cholesky_real_pullback(ΔC::Tangent, full_pb) = return full_pb(ΔC)[1:2] -function _cholesky_real_pullback(Ȳ::AbstractThunk, full_pb) - return _cholesky_real_pullback(unthunk(Ȳ), full_pb) -end -function rrule(::typeof(cholesky), - A::Union{ - Real, - Diagonal{<:Real}, - LinearAlgebra.HermOrSym{<:LinearAlgebra.BlasReal,<:StridedMatrix}, - StridedMatrix{<:LinearAlgebra.BlasReal} - } - # Handle not passing in the uplo -) - arg2 = A isa Real ? :U : Val(false) - C, full_pb = rrule(cholesky, A, arg2) - - cholesky_pullback(ȳ) = return _cholesky_real_pullback(ȳ, full_pb) - return C, cholesky_pullback -end - -function _cholesky_realuplo_pullback(ΔC::Tangent, C) - return NoTangent(), ΔC.factors[1, 1] / (2 * C.U[1, 1]), NoTangent() -end -_cholesky_realuplo_pullback(Ȳ::AbstractThunk, C) = _cholesky_realuplo_pullback(unthunk(Ȳ), C) -function rrule(::typeof(cholesky), A::Real, uplo::Symbol) - C = cholesky(A, uplo) - cholesky_pullback(ȳ) = _cholesky_realuplo_pullback(ȳ, C) +function rrule(::typeof(cholesky), x::Number, uplo::Symbol) + C = cholesky(x, uplo) + function cholesky_pullback(ΔC) + Ā = real(only(unthunk(ΔC).factors)) / (2 * sign(real(x)) * only(C.factors)) + return NoTangent(), Ā, NoTangent() + end return C, cholesky_pullback end -function _cholesky_Diagonal_pullback(ΔC::Tangent, C) - Ā = Diagonal(diag(ΔC.factors) .* inv.(2 .* C.factors.diag)) - return NoTangent(), Ā, NoTangent() +function _cholesky_Diagonal_pullback(ΔC, C) + Udiag = C.factors.diag + ΔUdiag = diag(ΔC.factors) + Ādiag = real.(ΔUdiag) ./ (2 .* Udiag) + if !issuccess(C) + # cholesky computes the factor diagonal from the beginning until it encounters the + # first failure. The remainder of the diagonal is then copied from the input. + i = findfirst(x -> !isreal(x) || !(real(x) > 0), Udiag) + Ādiag[i:end] .= ΔUdiag[i:end] + end + return NoTangent(), Diagonal(Ādiag), NoTangent() end -_cholesky_Diagonal_pullback(Ȳ::AbstractThunk, C) = _cholesky_Diagonal_pullback(unthunk(Ȳ), C) -function rrule(::typeof(cholesky), A::Diagonal{<:Real}, ::Val{false}; check::Bool=true) +function rrule(::typeof(cholesky), A::Diagonal{<:Number}, ::Val{false}; check::Bool=true) C = cholesky(A, Val(false); check=check) - cholesky_pullback(ȳ) = _cholesky_Diagonal_pullback(ȳ, C) + cholesky_pullback(ȳ) = _cholesky_Diagonal_pullback(unthunk(ȳ), C) return C, cholesky_pullback end @@ -489,46 +473,56 @@ end # Implementation due to Seeger, Matthias, et al. "Auto-differentiating linear algebra." function rrule( ::typeof(cholesky), - A::LinearAlgebra.HermOrSym{<:LinearAlgebra.BlasReal, <:StridedMatrix}, + A::LinearAlgebra.RealHermSymComplexHerm{<:Real, <:StridedMatrix}, ::Val{false}; check::Bool=true, ) C = cholesky(A, Val(false); check=check) - function _cholesky_HermOrSym_pullback(ΔC::Tangent) - Ā, U = _cholesky_pullback_shared_code(C, ΔC) - Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)) / 2, U.data, Ā) + function cholesky_HermOrSym_pullback(ΔC) + Ā = _cholesky_pullback_shared_code(C, unthunk(ΔC)) + rmul!(Ā, one(eltype(Ā)) / 2) return NoTangent(), _symhermtype(A)(Ā), NoTangent() end - _cholesky_HermOrSym_pullback(Ȳ::AbstractThunk) = _cholesky_HermOrSym_pullback(unthunk(Ȳ)) - return C, _cholesky_HermOrSym_pullback + return C, cholesky_HermOrSym_pullback end function rrule( ::typeof(cholesky), - A::StridedMatrix{<:LinearAlgebra.BlasReal}, + A::StridedMatrix{<:Union{Real,Complex}}, ::Val{false}; check::Bool=true, ) C = cholesky(A, Val(false); check=check) - function _cholesky_Strided_pullback(ΔC::Tangent) - Ā, U = _cholesky_pullback_shared_code(C, ΔC) - Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)), U.data, Ā) + function cholesky_Strided_pullback(ΔC) + Ā = _cholesky_pullback_shared_code(C, unthunk(ΔC)) idx = diagind(Ā) @views Ā[idx] .= real.(Ā[idx]) ./ 2 return (NoTangent(), UpperTriangular(Ā), NoTangent()) end - _cholesky_Strided_pullback(Ȳ::AbstractThunk) = _cholesky_Strided_pullback(unthunk(Ȳ)) - return C, _cholesky_Strided_pullback + return C, cholesky_Strided_pullback end function _cholesky_pullback_shared_code(C, ΔC) - U = C.U - Ū = ΔC.U - Ā = similar(U.data) - Ā = mul!(Ā, Ū, U') - Ā = LinearAlgebra.copytri!(Ā, 'U', true) - Ā = ldiv!(U, Ā) - return Ā, U + Δfactors = ΔC.factors + Ā = similar(C.factors) + if C.uplo === 'U' + U = C.U + Ū = eltype(U) <: Real ? real(_maybeUpperTri(Δfactors)) : _maybeUpperTri(Δfactors) + mul!(Ā, Ū, U') + LinearAlgebra.copytri!(Ā, 'U', true) + eltype(Ā) <: Real || _realifydiag!(Ā) + ldiv!(U, Ā) + rdiv!(Ā, U') + else # C.uplo === 'L' + L = C.L + L̄ = eltype(L) <: Real ? real(_maybeLowerTri(Δfactors)) : _maybeLowerTri(Δfactors) + mul!(Ā, L', L̄) + LinearAlgebra.copytri!(Ā, 'L', true) + eltype(Ā) <: Real || _realifydiag!(Ā) + rdiv!(Ā, L) + ldiv!(L', Ā) + end + return Ā end function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky} @@ -536,15 +530,15 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky} C = Tangent{T} ∂F = if x === :U if F.uplo === 'U' - C(U=UpperTriangular(Ȳ),) + C(factors=_maybeUpperTri(Ȳ),) else - C(L=LowerTriangular(Ȳ'),) + C(factors=_maybeLowerTri(Ȳ'),) end elseif x === :L if F.uplo === 'L' - C(L=LowerTriangular(Ȳ),) + C(factors=_maybeLowerTri(Ȳ),) else - C(U=UpperTriangular(Ȳ'),) + C(factors=_maybeUpperTri(Ȳ'),) end end return NoTangent(), ∂F, NoTangent() @@ -552,6 +546,11 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky} return getproperty(F, x), getproperty_cholesky_pullback end +_maybeUpperTri(A) = UpperTriangular(A) +_maybeUpperTri(A::Diagonal) = A +_maybeLowerTri(A) = LowerTriangular(A) +_maybeLowerTri(A::Diagonal) = A + # `det` and `logdet` for `Cholesky` function rrule(::typeof(det), C::Cholesky) y = det(C) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 2f6c4599a..2973ba893 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -382,57 +382,112 @@ end # have fantastic support for this stuff at the minute. # also we might be missing some overloads for different tangent-types in the rules @testset "cholesky" begin - @testset "Real" begin - test_rrule(cholesky, 0.8) + @testset "Number" begin + @testset "uplo=$uplo" for uplo in (:U, :L) + test_rrule(cholesky, 0.8, uplo) + test_rrule(cholesky, -0.3, uplo) + test_rrule(cholesky, 0.23 + 0im, uplo) + test_rrule(cholesky, 0.78 + 0.5im, uplo) + test_rrule(cholesky, -0.34 + 0.1im, uplo) + end end - @testset "Diagonal{<:Real}" begin - D = Diagonal(rand(5) .+ 0.1) - C = cholesky(D) - test_rrule( - cholesky, D ⊢ Diagonal(randn(5)), Val(false); - output_tangent=Tangent{typeof(C)}(factors=Diagonal(randn(5))) - ) + + @testset "Diagonal" begin + @testset "Diagonal{<:Real}" begin + test_rrule(cholesky, Diagonal([0.3, 0.2, 0.5, 0.6, 0.9]), Val(false)) + end + @testset "Diagonal{<:Complex}" begin + # finite differences in general will produce matrices with non-real + # diagonals, which cause factorization to fail. If we turn off the check and + # ensure the cotangent is real, then test_rrule still works. + D = Diagonal([0.3 + 0im, 0.2, 0.5, 0.6, 0.9]) + C = cholesky(D) + test_rrule( + cholesky, D, Val(false); + output_tangent=Tangent{typeof(C)}(factors=complex(randn(5, 5))), + fkwargs=(; check=false), + ) + end + @testset "check has correct default and passed to primal" begin + @test_throws Exception rrule(cholesky, Diagonal(-rand(5)), Val(false)) + rrule(cholesky, Diagonal(-rand(5)), Val(false); check=false) + end + @testset "failed factorization" begin + A = Diagonal(vcat(rand(4), -rand(4), rand(4))) + test_rrule(cholesky, A, Val(false); fkwargs=(; check=false)) + end end - X = generate_well_conditioned_matrix(10) - V = generate_well_conditioned_matrix(10) - F, dX_pullback = rrule(cholesky, X, Val(false)) - F_1arg, dX_pullback_1arg = rrule(cholesky, X) # to test not passing the Val(false) - @test F == F_1arg - @testset "uplo=$p" for p in [:U, :L] - Y, dF_pullback = rrule(getproperty, F, p) - Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(size(Y))) - (dself, dF, dp) = dF_pullback(Ȳ) - @test dself === NoTangent() - @test dp === NoTangent() - - # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp` - # machinery from FiniteDifferences because that isn't set up to respect - # necessary special properties of the input. In the case of the Cholesky - # factorization, we need the input to be Hermitian. - ΔF = unthunk(dF) - _, dX, darg2 = dX_pullback(ΔF) - _, dX_1arg = dX_pullback_1arg(ΔF) - @test dX == dX_1arg - @test darg2 === NoTangent() - X̄_ad = dot(unthunk(dX), V) - X̄_fd = central_fdm(5, 1)(0.000_001) do ε - dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p)) + @testset "StridedMatrix" begin + @testset "Matrix{$T}" for T in (Float64, ComplexF64) + X = generate_well_conditioned_matrix(T, 10) + V = generate_well_conditioned_matrix(T, 10) + F, dX_pullback = rrule(cholesky, X, Val(false)) + @testset "uplo=$p, cotangent eltype=$T" for p in [:U, :L], S in unique([T, complex(T)]) + Y, dF_pullback = rrule(getproperty, F, p) + Ȳ = randn(S, size(Y)) + (dself, dF, dp) = dF_pullback(Ȳ) + @test dself === NoTangent() + @test dp === NoTangent() + + # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp` + # machinery from FiniteDifferences because that isn't set up to respect + # necessary special properties of the input. In the case of the Cholesky + # factorization, we need the input to be Hermitian. + ΔF = unthunk(dF) + _, dX, darg2 = dX_pullback(ΔF) + @test darg2 === NoTangent() + X̄_ad = real(dot(unthunk(dX), V)) + X̄_fd = central_fdm(5, 1)(0.000_0001) do ε + real(dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p))) + end + @test X̄_ad ≈ X̄_fd rtol=1e-4 + end + end + @testset "check has correct default and passed to primal" begin + # this will almost certainly be a non-PD matrix + X = Matrix(Symmetric(randn(10, 10))) + @test_throws Exception rrule(cholesky, X, Val(false)) + rrule(cholesky, X, Val(false); check=false) # just check it doesn't throw end - @test X̄_ad ≈ X̄_fd rtol=1e-4 end # Ensure that cotangents of cholesky(::StridedMatrix) and # (cholesky ∘ Symmetric)(::StridedMatrix) are equal. @testset "Symmetric" begin + X = generate_well_conditioned_matrix(10) + F, dX_pullback = rrule(cholesky, X, Val(false)) + X_symmetric, sym_back = rrule(Symmetric, X, :U) C, chol_back_sym = rrule(cholesky, X_symmetric, Val(false)) - Δ = Tangent{typeof(C)}((U=UpperTriangular(randn(size(X))))) + Δ = Tangent{typeof(C)}((factors=randn(size(X)))) ΔX_symmetric = chol_back_sym(Δ)[2] @test sym_back(ΔX_symmetric)[2] ≈ dX_pullback(Δ)[2] end + # Ensure that cotangents of cholesky(::StridedMatrix) and + # (cholesky ∘ Hermitian)(::StridedMatrix) are equal. + @testset "Hermitian" begin + @testset "Hermitian{$T}" for T in (Float64, ComplexF64) + X = generate_well_conditioned_matrix(T, 10) + F, dX_pullback = rrule(cholesky, X, Val(false)) + + X_hermitian, herm_back = rrule(Hermitian, X, :U) + C, chol_back_herm = rrule(cholesky, X_hermitian, Val(false)) + + Δ = Tangent{typeof(C)}((factors=randn(T, size(X)))) + ΔX_hermitian = chol_back_herm(Δ)[2] + @test herm_back(ΔX_hermitian)[2] ≈ dX_pullback(Δ)[2] + end + @testset "check has correct default and passed to primal" begin + # this will almost certainly be a non-PD matrix + X = Hermitian(randn(10, 10)) + @test_throws Exception rrule(cholesky, X, Val(false)) + rrule(cholesky, X, Val(false); check=false) + end + end + @testset "det and logdet (uplo=$p)" for p in (:U, :L) @testset "$op" for op in (det, logdet) @testset "$T" for T in (Float64, ComplexF64)