diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 89a6c22..071ac47 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -18,8 +18,8 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          # - '1.6'
           - '1.7'
+          - '1.8'
           # - 'nightly'
         os:
           - ubuntu-latest
diff --git a/Project.toml b/Project.toml
index 961d2b0..7840aea 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "ImplicitDifferentiation"
 uuid = "57b37032-215b-411a-8a7c-41a003a55207"
 authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"]
-version = "0.2.0"
+version = "0.3.0"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
diff --git a/src/implicit_function.jl b/src/implicit_function.jl
index 7d9061b..95177f3 100644
--- a/src/implicit_function.jl
+++ b/src/implicit_function.jl
@@ -30,39 +30,47 @@ function ImplicitFunction(forward::F, conditions::C) where {F,C}
     return ImplicitFunction(forward, conditions, gmres)
 end
 
-struct SolverFailureException <: Exception
+struct SolverFailureException{S} <: Exception
     msg::String
+    stats::S
+end
+
+function Base.show(io::IO, sfe::SolverFailureException)
+    println(io, "SolverFailureException: $(sfe.msg) \n Solver stats: $(sfe.stats)")
 end
 
 """
-    implicit(x)
+    implicit(x[; kwargs...])
 
 Make [`ImplicitFunction{F,C,L}`](@ref) callable by applying `implicit.forward`.
 """
-(implicit::ImplicitFunction)(x) = implicit.forward(x)
+(implicit::ImplicitFunction)(x; kwargs...) = implicit.forward(x; kwargs...)
 
 """
-    frule(rc, (_, dx), implicit, x)
+    frule(rc, (_, dx), implicit, x[; kwargs...])
 
 Custom forward rule for [`ImplicitFunction{F,C,L}`](@ref).
 
 We compute the Jacobian-vector product `Jv` by solving `Au = Bv` and setting `Jv = u`.
+Keyword arguments are given to both `implicit.forward` and `implicit.conditions`.
 """
 function ChainRulesCore.frule(
-    rc::RuleConfig, (_, dx), implicit::ImplicitFunction, x::AbstractArray{R}
+    rc::RuleConfig, (_, dx), implicit::ImplicitFunction, x::AbstractArray{R}; kwargs...
 ) where {R<:Real}
-    (; forward, conditions, linear_solver) = implicit
+    forward = implicit.forward
+    conditions = implicit.conditions
+    linear_solver = implicit.linear_solver
 
-    y = forward(x)
+    y = forward(x; kwargs...)
 
-    conditions_x(x̃) = conditions(x̃, y)
-    conditions_y(ỹ) = -conditions(x, ỹ)
+    conditions_x(x̃; kwargs...) = conditions(x̃, y; kwargs...)
+    conditions_y(ỹ; kwargs...) = -conditions(x, ỹ; kwargs...)
 
-    pushforward_A(dỹ) = frule_via_ad(rc, (NoTangent(), dỹ), conditions_y, y)[2]
-    pushforward_B(dx̃) = frule_via_ad(rc, (NoTangent(), dx̃), conditions_x, x)[2]
+    pushforward_A(dỹ) = frule_via_ad(rc, (NoTangent(), dỹ), conditions_y, y; kwargs...)[2]
+    pushforward_B(dx̃) = frule_via_ad(rc, (NoTangent(), dx̃), conditions_x, x; kwargs...)[2]
 
-    mul_A!(res, u::AbstractVector) = res .= vec(pushforward_A(reshape(u, size(y))))
-    mul_B!(res, v::AbstractVector) = res .= vec(pushforward_B(reshape(v, size(x))))
+    mul_A!(res::Vector, u::Vector) = res .= vec(pushforward_A(reshape(u, size(y))))
+    mul_B!(res::Vector, v::Vector) = res .= vec(pushforward_B(reshape(v, size(x))))
 
     n, m = length(x), length(y)
     A = LinearOperator(R, m, m, false, false, mul_A!)
@@ -71,34 +79,39 @@ function ChainRulesCore.frule(
     dx_vec = convert(Vector{R}, vec(unthunk(dx)))
     b = B * dx_vec
     dy_vec, stats = linear_solver(A, b)
-    stats.solved || throw(SolverFailureException("Linear solver failed to converge"))
+    if !stats.solved
+        throw(SolverFailureException("Linear solver failed to converge", stats))
+    end
     dy = reshape(dy_vec, size(y))
     return y, dy
 end
 
 """
-    rrule(rc, implicit, x)
+    rrule(rc, implicit, x[; kwargs...])
 
 Custom reverse rule for [`ImplicitFunction{F,C,L}`](@ref).
 
 We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = Bᵀu`.
+Keyword arguments are given to both `implicit.forward` and `implicit.conditions`.
 """
 function ChainRulesCore.rrule(
-    rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}
+    rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}; kwargs...
 ) where {R<:Real}
-    (; forward, conditions, linear_solver) = implicit
+    forward = implicit.forward
+    conditions = implicit.conditions
+    linear_solver = implicit.linear_solver
 
-    y = forward(x)
+    y = forward(x; kwargs...)
 
-    conditions_x(x̃) = conditions(x̃, y)
-    conditions_y(ỹ) = -conditions(x, ỹ)
+    conditions_x(x̃; kwargs...) = conditions(x̃, y; kwargs...)
+    conditions_y(ỹ; kwargs...) = -conditions(x, ỹ; kwargs...)
 
-    pullback_Aᵀ = last ∘ rrule_via_ad(rc, conditions_y, y)[2]
-    pullback_Bᵀ = last ∘ rrule_via_ad(rc, conditions_x, x)[2]
+    pullback_Aᵀ = rrule_via_ad(rc, conditions_y, y; kwargs...)[2]
+    pullback_Bᵀ = rrule_via_ad(rc, conditions_x, x; kwargs...)[2]
 
-    mul_Aᵀ!(res, u::AbstractVector) = res .= vec(pullback_Aᵀ(reshape(u, size(y))))
-    mul_Bᵀ!(res, v::AbstractVector) = res .= vec(pullback_Bᵀ(reshape(v, size(y))))
+    mul_Aᵀ!(res::Vector, u::Vector) = res .= vec(pullback_Aᵀ(reshape(u, size(y)))[2])
+    mul_Bᵀ!(res::Vector, v::Vector) = res .= vec(pullback_Bᵀ(reshape(v, size(y)))[2])
 
     n, m = length(x), length(y)
     Aᵀ = LinearOperator(R, m, m, false, false, mul_Aᵀ!)
@@ -107,7 +120,9 @@ function ChainRulesCore.rrule(
     function implicit_pullback(dy)
         dy_vec = convert(Vector{R}, vec(unthunk(dy)))
         u, stats = linear_solver(Aᵀ, dy_vec)
-        stats.solved || throw(SolverFailureException("Linear solver failed to converge"))
+        if !stats.solved
+            throw(SolverFailureException("Linear solver failed to converge", stats))
+        end
         dx_vec = Bᵀ * u
         dx = reshape(dx_vec, size(x))
         return (NoTangent(), dx)
diff --git a/test/1_unconstrained_optimization.jl b/test/1_unconstrained_optimization.jl
index fdd92e5..f71ea45 100644
--- a/test/1_unconstrained_optimization.jl
+++ b/test/1_unconstrained_optimization.jl
@@ -66,7 +66,11 @@ Zygote.jacobian(implicit, x)[1]
 
 # Note that implicit differentiation was necessary here, since our solver alone doesn't support autodiff with `Zygote.jl`.
 
-try; Zygote.jacobian(dumb_identity, x)[1]; catch e; @error e; end
+try
+    Zygote.jacobian(dumb_identity, x)[1]
+catch e
+    e
+end
 
 # The following tests are not included in the docs. #src
 
diff --git a/test/2_sparse_linear_regression.jl b/test/2_sparse_linear_regression.jl
index 6f903da..e1179e5 100644
--- a/test/2_sparse_linear_regression.jl
+++ b/test/2_sparse_linear_regression.jl
@@ -93,7 +93,11 @@ round.(implicit(data); digits=4)
 
 # Note that implicit differentiation is necessary here because the convex solver breaks autodiff.
 
-try; Zygote.jacobian(lasso, data); catch e; @error e; end
+try
+    Zygote.jacobian(lasso, data)
+catch e
+    e
+end
 
 # Meanwhile, our implicit wrapper makes autodiff work seamlessly.
 
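Aside (not part of the patch): a minimal usage sketch of the keyword passthrough introduced in src/implicit_function.jl above. The `forward` solver, `conditions` function, and the parameter `p` below are hypothetical, and the sketch assumes Zygote is available; the pattern mirrors how the optimal transport test calls `implicit(C; a=a, b=b, ε=ε, T=T)` inside a Zygote closure.

using ImplicitDifferentiation, Zygote

forward(x; p) = p .* sqrt.(x)             # hypothetical solver defining y(x)
conditions(x, y; p) = y .^ 2 .- p^2 .* x  # optimality conditions c(x, y) = 0
implicit = ImplicitFunction(forward, conditions)

x = rand(3)
implicit(x; p=2.0)                              # kwargs are forwarded to `forward`
Zygote.jacobian(x -> implicit(x; p=2.0), x)[1]  # kwargs also reach `conditions` inside the custom rrule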
diff --git a/test/3_optimal_transport.jl b/test/3_optimal_transport.jl
index 9737a14..7240109 100644
--- a/test/3_optimal_transport.jl
+++ b/test/3_optimal_transport.jl
@@ -69,21 +69,35 @@ Y = rand(d, m)
 a = fill(1 / n, n)
 b = fill(1 / m, m)
 
-C = pairwise(SqEuclidean(), X, Y, dims=2)
+C = pairwise(SqEuclidean(), X, Y; dims=2)
 
-ε = 1.;
+ε = 1.0;
+T = 100;
 
 # ## Forward solver
 
 # For technical reasons related to optimality checking, our Sinkhorn solver returns ``\hat{u}`` instead of ``\hat{p}_\varepsilon``.
 
-function sinkhorn(C; a=a, b=b, ε=ε)
+function sinkhorn(C; a, b, ε, T)
     K = exp.(.-C ./ ε)
     u = copy(a)
     v = copy(b)
-    for t in 1:100
-        u .= a ./ (K * v)
-        v .= b ./ (K' * u)
+    for t in 1:T
+        u = a ./ (K * v)
+        v = b ./ (K' * u)
+    end
+    return u
+end
+
+function sinkhorn_efficient(C; a, b, ε, T)
+    K = exp.(.-C ./ ε)
+    u = copy(a)
+    v = copy(b)
+    for t in 1:T
+        mul!(u, K, v)
+        u .= a ./ u
+        mul!(v, K', u)
+        v .= b ./ v
     end
     return u
 end
@@ -92,7 +106,7 @@ end
 
 # We simply used the fixed point equation $(\text{S})$.
 
-function sinkhorn_fixed_point(C, u; a=a, b=b, ε=ε)
+function sinkhorn_fixed_point(C, u; a, b, ε, T=nothing)
     K = exp.(.-C ./ ε)
     v = b ./ (K' * u)
     return u .- a ./ (K * v)
@@ -100,34 +114,58 @@ end
 
 # We have all we need to build a differentiable Sinkhorn that doesn't require unrolling the fixed point iterations.
 
-implicit = ImplicitFunction(sinkhorn, sinkhorn_fixed_point);
+implicit = ImplicitFunction(sinkhorn_efficient, sinkhorn_fixed_point);
 
 # ## Testing
 
-u = sinkhorn(C)
+u1 = sinkhorn(C; a=a, b=b, ε=ε, T=T)
+u2 = implicit(C; a=a, b=b, ε=ε, T=T)
+u1 == u2
 
 # First, let us check that the forward pass works correctly and returns a fixed point.
 
-maximum(abs, sinkhorn_fixed_point(C, u))
+all(iszero, sinkhorn_fixed_point(C, u1; a=a, b=b, ε=ε, T=T))
 
-# Using the implicit function defined above, we can build an autodiff-compatible implementation of `transportation_plan` which does not require backpropagating through the Sinkhorn iterations:
+# Using the implicit function defined above, we can build an autodiff-compatible Sinkhorn which does not require backpropagating through the fixed point iterations:
 
-function transportation_plan(C; a=a, b=b, ε=ε)
+function transportation_plan_slow(C; a, b, ε, T)
     K = exp.(.-C ./ ε)
-    u = implicit(C)
+    u = sinkhorn(C; a=a, b=b, ε=ε, T=T)
     v = b ./ (K' * u)
-    p_vec = vec(u .* K .* v')
-    return p_vec
+    p = u .* K .* v'
+    return p
 end;
 
+function transportation_plan_fast(C; a, b, ε, T)
+    K = exp.(.-C ./ ε)
+    u = implicit(C; a=a, b=b, ε=ε, T=T)
+    v = b ./ (K' * u)
+    p = u .* K .* v'
+    return p
+end;
+
+# What does the transportation plan look like?
+
+p1 = transportation_plan_slow(C; a=a, b=b, ε=ε, T=T)
+p2 = transportation_plan_fast(C; a=a, b=b, ε=ε, T=T)
+p1 == p2
+
 # Let us compare its Jacobian with the one obtained using finite differences.
 
-J = Zygote.jacobian(transportation_plan, C)[1]
-J_ref = FiniteDifferences.jacobian(central_fdm(5, 1), transportation_plan, C)[1]
-sum(abs, J - J_ref) / prod(size(J))
+J1 = Zygote.jacobian(C -> transportation_plan_slow(C; a=a, b=b, ε=ε, T=T), C)[1]
+J2 = Zygote.jacobian(C -> transportation_plan_fast(C; a=a, b=b, ε=ε, T=T), C)[1]
+J_ref = FiniteDifferences.jacobian(
+    central_fdm(5, 1), C -> transportation_plan_slow(C; a=a, b=b, ε=ε, T=T), C
+)[1]
+
+sum(abs, J2 - J_ref) / prod(size(J_ref))
 
 # The following tests are not included in the docs. #src
 
 @testset verbose = true "FiniteDifferences" begin #src
-    @test sum(abs, J - J_ref) / prod(size(J)) < 1e-5 #src
+    @test u1 == u2 #src
+    @test all(iszero, sinkhorn_fixed_point(C, u1; a=a, b=b, ε=ε, T=T)) #src
+    @test p1 == p2 #src
+    @test sum(abs, J1 - J_ref) / prod(size(J_ref)) < 1e-5 #src
+    @test sum(abs, J2 - J_ref) / prod(size(J_ref)) < 1e-5 #src
 end #src
diff --git a/test/Manifest.toml b/test/Manifest.toml
index c857a34..8195bad 100644
--- a/test/Manifest.toml
+++ b/test/Manifest.toml
@@ -2,6 +2,7 @@
 
 julia_version = "1.7.3"
 manifest_format = "2.0"
+project_hash = "252ec556cc576958bb4293ddf4b3affa1825c9d0"
 
 [[deps.AMD]]
 deps = ["Libdl", "LinearAlgebra", "SparseArrays", "Test"]
@@ -89,12 +90,6 @@ git-tree-sha1 = "1e315e3f4b0b7ce40feded39c73049692126cf53"
 uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
 version = "0.1.3"
 
-[[deps.CodeTracking]]
-deps = ["InteractiveUtils", "UUIDs"]
-git-tree-sha1 = "6d4fa04343a7fc9f9cb9cff9558929f3d2752717"
-uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
-version = "1.0.9"
-
 [[deps.CodecBzip2]]
 deps = ["Bzip2_jll", "Libdl", "TranscodingStreams"]
 git-tree-sha1 = "2e62a725210ce3c3c2e1a3080190e7ca491f18d7"
@@ -249,12 +244,6 @@ git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151"
 uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
 version = "0.1.1"
 
-[[deps.JET]]
-deps = ["InteractiveUtils", "JuliaInterpreter", "LoweredCodeUtils", "MacroTools", "Pkg", "Revise", "Test"]
-git-tree-sha1 = "8e78b0c297cfa6cefd579f87232c89bd6ed7a081"
-uuid = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
-version = "0.5.16"
-
 [[deps.JLLWrappers]]
 deps = ["Preferences"]
 git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1"
@@ -267,12 +256,6 @@ git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e"
 uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
 version = "0.21.3"
 
-[[deps.JuliaInterpreter]]
-deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"]
-git-tree-sha1 = "52617c41d2761cc05ed81fe779804d3b7f14fff7"
-uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
-version = "0.9.13"
-
 [[deps.Krylov]]
 deps = ["LinearAlgebra", "Printf", "SparseArrays"]
 git-tree-sha1 = "7f0a89bd74c30aa7ff96c4bf1bc884c39663a621"
@@ -329,12 +312,6 @@ version = "0.3.15"
 
 [[deps.Logging]]
 uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
 
-[[deps.LoweredCodeUtils]]
-deps = ["JuliaInterpreter"]
-git-tree-sha1 = "dedbebe234e06e1ddad435f5c6f4b85cd8ce55f7"
-uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b"
-version = "2.2.2"
-
 [[deps.MacroTools]]
 deps = ["Markdown", "Random"]
 git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf"
@@ -480,12 +457,6 @@ git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
 uuid = "ae029012-a4dd-5104-9daa-d747884805df"
 version = "1.3.0"
 
-[[deps.Revise]]
-deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "Pkg", "REPL", "Requires", "UUIDs", "Unicode"]
-git-tree-sha1 = "4d4239e93531ac3e7ca7e339f15978d0b5149d03"
-uuid = "295af30f-e4ad-537b-8983-00126c2a3abe"
-version = "3.3.3"
-
 [[deps.Richardson]]
 deps = ["LinearAlgebra"]
 git-tree-sha1 = "e03ca566bec93f8a3aeb059c8ef102f268a38949"
diff --git a/test/Project.toml b/test/Project.toml
index c9bbdb2..b7e0d33 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -6,7 +6,6 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 Convex = "f65535da-76fb-5f13-bab9-19810c17039a"
 Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
-JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
"ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" @@ -27,7 +26,6 @@ ComponentArrays = "0.12.2" Convex = "0.15.1" Distances = "0.10.7" FiniteDifferences = "0.12.24" -JET = "0.5.16" Krylov = "0.8.2" LinearOperators = "2.3.2" MathOptInterface = "1.5" diff --git a/test/runtests.jl b/test/runtests.jl index 3159315..67e101c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,21 +2,14 @@ using Aqua using ImplicitDifferentiation -using JET using Random using Test ## Test sets @testset verbose = true "ImplicitDifferentiation.jl" begin - @testset verbose = true "Code quality" begin - @testset verbose = true "JET" begin - jet_report = JET.report_package(ImplicitDifferentiation) - @test string(jet_report) == "No errors detected\n" - end - @testset verbose = true "Aqua" begin - Aqua.test_all(ImplicitDifferentiation) - end + @testset verbose = true "Code quality (Aqua)" begin + Aqua.test_all(ImplicitDifferentiation) end @testset verbose = true "Unconstrained optimization" begin include("1_unconstrained_optimization.jl")