From 2f4ca045fada127204a50aaaf8e87c09410651ec Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 29 Aug 2022 15:33:43 +0200 Subject: [PATCH 1/8] Fix JET test in v1.8 --- .github/workflows/CI.yml | 3 ++- src/implicit_function.jl | 16 ++++++++++------ test/1_unconstrained_optimization.jl | 6 +++++- test/2_sparse_linear_regression.jl | 6 +++++- 4 files changed, 22 insertions(+), 9 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 89a6c22..5a9926a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -18,8 +18,9 @@ jobs: fail-fast: false matrix: version: - # - '1.6' + - '1.6' - '1.7' + - '1.8' # - 'nightly' os: - ubuntu-latest diff --git a/src/implicit_function.jl b/src/implicit_function.jl index 7d9061b..e127ad0 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -51,7 +51,9 @@ We compute the Jacobian-vector product `Jv` by solving `Au = Bv` and setting `Jv function ChainRulesCore.frule( rc::RuleConfig, (_, dx), implicit::ImplicitFunction, x::AbstractArray{R} ) where {R<:Real} - (; forward, conditions, linear_solver) = implicit + forward = implicit.forward + conditions = implicit.conditions + linear_solver = implicit.linear_solver y = forward(x) @@ -61,8 +63,8 @@ function ChainRulesCore.frule( pushforward_A(dỹ) = frule_via_ad(rc, (NoTangent(), dỹ), conditions_y, y)[2] pushforward_B(dx̃) = frule_via_ad(rc, (NoTangent(), dx̃), conditions_x, x)[2] - mul_A!(res, u::AbstractVector) = res .= vec(pushforward_A(reshape(u, size(y)))) - mul_B!(res, v::AbstractVector) = res .= vec(pushforward_B(reshape(v, size(x)))) + mul_A!(res::Vector, u::Vector) = res .= vec(pushforward_A(reshape(u, size(y)))) + mul_B!(res::Vector, v::Vector) = res .= vec(pushforward_B(reshape(v, size(x)))) n, m = length(x), length(y) A = LinearOperator(R, m, m, false, false, mul_A!) @@ -87,7 +89,9 @@ We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and settin function ChainRulesCore.rrule( rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R} ) where {R<:Real} - (; forward, conditions, linear_solver) = implicit + forward = implicit.forward + conditions = implicit.conditions + linear_solver = implicit.linear_solver y = forward(x) @@ -97,8 +101,8 @@ function ChainRulesCore.rrule( pullback_Aᵀ = last ∘ rrule_via_ad(rc, conditions_y, y)[2] pullback_Bᵀ = last ∘ rrule_via_ad(rc, conditions_x, x)[2] - mul_Aᵀ!(res, u::AbstractVector) = res .= vec(pullback_Aᵀ(reshape(u, size(y)))) - mul_Bᵀ!(res, v::AbstractVector) = res .= vec(pullback_Bᵀ(reshape(v, size(y)))) + mul_Aᵀ!(res::Vector, u::Vector) = res .= vec(pullback_Aᵀ(reshape(u, size(y)))) + mul_Bᵀ!(res::Vector, v::Vector) = res .= vec(pullback_Bᵀ(reshape(v, size(y)))) n, m = length(x), length(y) Aᵀ = LinearOperator(R, m, m, false, false, mul_Aᵀ!) diff --git a/test/1_unconstrained_optimization.jl b/test/1_unconstrained_optimization.jl index fdd92e5..f71ea45 100644 --- a/test/1_unconstrained_optimization.jl +++ b/test/1_unconstrained_optimization.jl @@ -66,7 +66,11 @@ Zygote.jacobian(implicit, x)[1] # Note that implicit differentiation was necessary here, since our solver alone doesn't support autodiff with `Zygote.jl`. -try; Zygote.jacobian(dumb_identity, x)[1]; catch e; @error e; end +try + Zygote.jacobian(dumb_identity, x)[1] +catch e + e +end # The following tests are not included in the docs. #src diff --git a/test/2_sparse_linear_regression.jl b/test/2_sparse_linear_regression.jl index 6f903da..e1179e5 100644 --- a/test/2_sparse_linear_regression.jl +++ b/test/2_sparse_linear_regression.jl @@ -93,7 +93,11 @@ round.(implicit(data); digits=4) # Note that implicit differentiation is necessary here because the convex solver breaks autodiff. -try; Zygote.jacobian(lasso, data); catch e; @error e; end +try + Zygote.jacobian(lasso, data) +catch e + e +end # Meanwhile, our implicit wrapper makes autodiff work seamlessly. From dd8e609012c9a6c355605d63f57bcca7745abf65 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 29 Aug 2022 15:39:51 +0200 Subject: [PATCH 2/8] Fix compat with 1.6 --- Project.toml | 4 ++-- src/implicit_function.jl | 10 +++++----- test/Manifest.toml | 31 +------------------------------ test/Project.toml | 2 -- test/runtests.jl | 11 ++--------- 5 files changed, 10 insertions(+), 48 deletions(-) diff --git a/Project.toml b/Project.toml index 961d2b0..dd0b8e4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ImplicitDifferentiation" uuid = "57b37032-215b-411a-8a7c-41a003a55207" authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"] -version = "0.2.0" +version = "0.3.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -12,4 +12,4 @@ LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" ChainRulesCore = "1.14" Krylov = "0.8.1" LinearOperators = "2.2.3" -julia = "1.7" +julia = "1.6" diff --git a/src/implicit_function.jl b/src/implicit_function.jl index e127ad0..464d8af 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -39,7 +39,7 @@ end Make [`ImplicitFunction{F,C,L}`](@ref) callable by applying `implicit.forward`. """ -(implicit::ImplicitFunction)(x) = implicit.forward(x) +(implicit::ImplicitFunction)(x; kwargs...) = implicit.forward(x; kwargs...) """ frule(rc, (_, dx), implicit, x) @@ -49,13 +49,13 @@ Custom forward rule for [`ImplicitFunction{F,C,L}`](@ref). We compute the Jacobian-vector product `Jv` by solving `Au = Bv` and setting `Jv = u`. """ function ChainRulesCore.frule( - rc::RuleConfig, (_, dx), implicit::ImplicitFunction, x::AbstractArray{R} + rc::RuleConfig, (_, dx), implicit::ImplicitFunction, x::AbstractArray{R}; kwargs... ) where {R<:Real} forward = implicit.forward conditions = implicit.conditions linear_solver = implicit.linear_solver - y = forward(x) + y = forward(x; kwargs...) conditions_x(x̃) = conditions(x̃, y) conditions_y(ỹ) = -conditions(x, ỹ) @@ -87,13 +87,13 @@ Custom reverse rule for [`ImplicitFunction{F,C,L}`](@ref). We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = Bᵀu`. """ function ChainRulesCore.rrule( - rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R} + rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}; kwargs... ) where {R<:Real} forward = implicit.forward conditions = implicit.conditions linear_solver = implicit.linear_solver - y = forward(x) + y = forward(x; kwargs...) conditions_x(x̃) = conditions(x̃, y) conditions_y(ỹ) = -conditions(x, ỹ) diff --git a/test/Manifest.toml b/test/Manifest.toml index c857a34..8195bad 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -2,6 +2,7 @@ julia_version = "1.7.3" manifest_format = "2.0" +project_hash = "252ec556cc576958bb4293ddf4b3affa1825c9d0" [[deps.AMD]] deps = ["Libdl", "LinearAlgebra", "SparseArrays", "Test"] @@ -89,12 +90,6 @@ git-tree-sha1 = "1e315e3f4b0b7ce40feded39c73049692126cf53" uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" version = "0.1.3" -[[deps.CodeTracking]] -deps = ["InteractiveUtils", "UUIDs"] -git-tree-sha1 = "6d4fa04343a7fc9f9cb9cff9558929f3d2752717" -uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" -version = "1.0.9" - [[deps.CodecBzip2]] deps = ["Bzip2_jll", "Libdl", "TranscodingStreams"] git-tree-sha1 = "2e62a725210ce3c3c2e1a3080190e7ca491f18d7" @@ -249,12 +244,6 @@ git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" version = "0.1.1" -[[deps.JET]] -deps = ["InteractiveUtils", "JuliaInterpreter", "LoweredCodeUtils", "MacroTools", "Pkg", "Revise", "Test"] -git-tree-sha1 = "8e78b0c297cfa6cefd579f87232c89bd6ed7a081" -uuid = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" -version = "0.5.16" - [[deps.JLLWrappers]] deps = ["Preferences"] git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" @@ -267,12 +256,6 @@ git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" version = "0.21.3" -[[deps.JuliaInterpreter]] -deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"] -git-tree-sha1 = "52617c41d2761cc05ed81fe779804d3b7f14fff7" -uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" -version = "0.9.13" - [[deps.Krylov]] deps = ["LinearAlgebra", "Printf", "SparseArrays"] git-tree-sha1 = "7f0a89bd74c30aa7ff96c4bf1bc884c39663a621" @@ -329,12 +312,6 @@ version = "0.3.15" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" -[[deps.LoweredCodeUtils]] -deps = ["JuliaInterpreter"] -git-tree-sha1 = "dedbebe234e06e1ddad435f5c6f4b85cd8ce55f7" -uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b" -version = "2.2.2" - [[deps.MacroTools]] deps = ["Markdown", "Random"] git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf" @@ -480,12 +457,6 @@ git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" uuid = "ae029012-a4dd-5104-9daa-d747884805df" version = "1.3.0" -[[deps.Revise]] -deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "Pkg", "REPL", "Requires", "UUIDs", "Unicode"] -git-tree-sha1 = "4d4239e93531ac3e7ca7e339f15978d0b5149d03" -uuid = "295af30f-e4ad-537b-8983-00126c2a3abe" -version = "3.3.3" - [[deps.Richardson]] deps = ["LinearAlgebra"] git-tree-sha1 = "e03ca566bec93f8a3aeb059c8ef102f268a38949" diff --git a/test/Project.toml b/test/Project.toml index c9bbdb2..b7e0d33 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,7 +6,6 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Convex = "f65535da-76fb-5f13-bab9-19810c17039a" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" @@ -27,7 +26,6 @@ ComponentArrays = "0.12.2" Convex = "0.15.1" Distances = "0.10.7" FiniteDifferences = "0.12.24" -JET = "0.5.16" Krylov = "0.8.2" LinearOperators = "2.3.2" MathOptInterface = "1.5" diff --git a/test/runtests.jl b/test/runtests.jl index 3159315..67e101c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,21 +2,14 @@ using Aqua using ImplicitDifferentiation -using JET using Random using Test ## Test sets @testset verbose = true "ImplicitDifferentiation.jl" begin - @testset verbose = true "Code quality" begin - @testset verbose = true "JET" begin - jet_report = JET.report_package(ImplicitDifferentiation) - @test string(jet_report) == "No errors detected\n" - end - @testset verbose = true "Aqua" begin - Aqua.test_all(ImplicitDifferentiation) - end + @testset verbose = true "Code quality (Aqua)" begin + Aqua.test_all(ImplicitDifferentiation) end @testset verbose = true "Unconstrained optimization" begin include("1_unconstrained_optimization.jl") From 1d2e039eb2e1e11c558979dbc1c02fb31347cc67 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 29 Aug 2022 15:45:50 +0200 Subject: [PATCH 3/8] Add documentation for kwargs --- src/implicit_function.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/implicit_function.jl b/src/implicit_function.jl index 464d8af..1bacda9 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -35,18 +35,19 @@ struct SolverFailureException <: Exception end """ - implicit(x) + implicit(x[; kwargs...]) Make [`ImplicitFunction{F,C,L}`](@ref) callable by applying `implicit.forward`. """ (implicit::ImplicitFunction)(x; kwargs...) = implicit.forward(x; kwargs...) """ - frule(rc, (_, dx), implicit, x) + frule(rc, (_, dx), implicit, x[; kwargs...]) Custom forward rule for [`ImplicitFunction{F,C,L}`](@ref). We compute the Jacobian-vector product `Jv` by solving `Au = Bv` and setting `Jv = u`. +Keyword arguments are given to `implicit.forward`, not to `implicit.conditions`. """ function ChainRulesCore.frule( rc::RuleConfig, (_, dx), implicit::ImplicitFunction, x::AbstractArray{R}; kwargs... @@ -80,11 +81,12 @@ function ChainRulesCore.frule( end """ - rrule(rc, implicit, x) + rrule(rc, implicit, x[; kwargs...]) Custom reverse rule for [`ImplicitFunction{F,C,L}`](@ref). We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = Bᵀu`. +Keyword arguments are given to `implicit.forward`, not to `implicit.conditions`. """ function ChainRulesCore.rrule( rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}; kwargs... From 15534b28767207bdcb71d21dc8cff64a0168ea70 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 29 Aug 2022 17:00:51 +0200 Subject: [PATCH 4/8] Compat back to 1.7 --- .github/workflows/CI.yml | 1 - Project.toml | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5a9926a..071ac47 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -18,7 +18,6 @@ jobs: fail-fast: false matrix: version: - - '1.6' - '1.7' - '1.8' # - 'nightly' diff --git a/Project.toml b/Project.toml index dd0b8e4..7840aea 100644 --- a/Project.toml +++ b/Project.toml @@ -12,4 +12,4 @@ LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" ChainRulesCore = "1.14" Krylov = "0.8.1" LinearOperators = "2.2.3" -julia = "1.6" +julia = "1.7" From 25d423cacc0c518c0937717278cf370dc5282525 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 29 Aug 2022 17:27:41 +0200 Subject: [PATCH 5/8] Fix optimal transport kwargs --- src/implicit_function.jl | 20 +++++----- test/3_optimal_transport.jl | 76 +++++++++++++++++++++++++++---------- 2 files changed, 67 insertions(+), 29 deletions(-) diff --git a/src/implicit_function.jl b/src/implicit_function.jl index 1bacda9..a2d6cb0 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -58,11 +58,11 @@ function ChainRulesCore.frule( y = forward(x; kwargs...) - conditions_x(x̃) = conditions(x̃, y) - conditions_y(ỹ) = -conditions(x, ỹ) + conditions_x(x̃; kwargs...) = conditions(x̃, y; kwargs...) + conditions_y(ỹ; kwargs...) = -conditions(x, ỹ; kwargs...) - pushforward_A(dỹ) = frule_via_ad(rc, (NoTangent(), dỹ), conditions_y, y)[2] - pushforward_B(dx̃) = frule_via_ad(rc, (NoTangent(), dx̃), conditions_x, x)[2] + pushforward_A(dỹ) = frule_via_ad(rc, (NoTangent(), dỹ), conditions_y, y; kwargs...)[2] + pushforward_B(dx̃) = frule_via_ad(rc, (NoTangent(), dx̃), conditions_x, x; kwargs...)[2] mul_A!(res::Vector, u::Vector) = res .= vec(pushforward_A(reshape(u, size(y)))) mul_B!(res::Vector, v::Vector) = res .= vec(pushforward_B(reshape(v, size(x)))) @@ -97,14 +97,14 @@ function ChainRulesCore.rrule( y = forward(x; kwargs...) - conditions_x(x̃) = conditions(x̃, y) - conditions_y(ỹ) = -conditions(x, ỹ) + conditions_x(x̃; kwargs...) = conditions(x̃, y; kwargs...) + conditions_y(ỹ; kwargs...) = -conditions(x, ỹ; kwargs...) - pullback_Aᵀ = last ∘ rrule_via_ad(rc, conditions_y, y)[2] - pullback_Bᵀ = last ∘ rrule_via_ad(rc, conditions_x, x)[2] + pullback_Aᵀ = rrule_via_ad(rc, conditions_y, y; kwargs...)[2] + pullback_Bᵀ = rrule_via_ad(rc, conditions_x, x; kwargs...)[2] - mul_Aᵀ!(res::Vector, u::Vector) = res .= vec(pullback_Aᵀ(reshape(u, size(y)))) - mul_Bᵀ!(res::Vector, v::Vector) = res .= vec(pullback_Bᵀ(reshape(v, size(y)))) + mul_Aᵀ!(res::Vector, u::Vector) = res .= vec(pullback_Aᵀ(reshape(u, size(y)))[2]) + mul_Bᵀ!(res::Vector, v::Vector) = res .= vec(pullback_Bᵀ(reshape(v, size(y)))[2]) n, m = length(x), length(y) Aᵀ = LinearOperator(R, m, m, false, false, mul_Aᵀ!) diff --git a/test/3_optimal_transport.jl b/test/3_optimal_transport.jl index 9737a14..7240109 100644 --- a/test/3_optimal_transport.jl +++ b/test/3_optimal_transport.jl @@ -69,21 +69,35 @@ Y = rand(d, m) a = fill(1 / n, n) b = fill(1 / m, m) -C = pairwise(SqEuclidean(), X, Y, dims=2) +C = pairwise(SqEuclidean(), X, Y; dims=2) -ε = 1.; +ε = 1.0; +T = 100; # ## Forward solver # For technical reasons related to optimality checking, our Sinkhorn solver returns ``\hat{u}`` instead of ``\hat{p}_\varepsilon``. -function sinkhorn(C; a=a, b=b, ε=ε) +function sinkhorn(C; a, b, ε, T) K = exp.(.-C ./ ε) u = copy(a) v = copy(b) - for t in 1:100 - u .= a ./ (K * v) - v .= b ./ (K' * u) + for t in 1:T + u = a ./ (K * v) + v = b ./ (K' * u) + end + return u +end + +function sinkhorn_efficient(C; a, b, ε, T) + K = exp.(.-C ./ ε) + u = copy(a) + v = copy(b) + for t in 1:T + mul!(u, K, v) + u .= a ./ u + mul!(v, K', u) + v .= b ./ v end return u end @@ -92,7 +106,7 @@ end # We simply used the fixed point equation $(\text{S})$. -function sinkhorn_fixed_point(C, u; a=a, b=b, ε=ε) +function sinkhorn_fixed_point(C, u; a, b, ε, T=nothing) K = exp.(.-C ./ ε) v = b ./ (K' * u) return u .- a ./ (K * v) @@ -100,34 +114,58 @@ end # We have all we need to build a differentiable Sinkhorn that doesn't require unrolling the fixed point iterations. -implicit = ImplicitFunction(sinkhorn, sinkhorn_fixed_point); +implicit = ImplicitFunction(sinkhorn_efficient, sinkhorn_fixed_point); # ## Testing -u = sinkhorn(C) +u1 = sinkhorn(C; a=a, b=b, ε=ε, T=T) +u2 = implicit(C; a=a, b=b, ε=ε, T=T) +u1 == u2 # First, let us check that the forward pass works correctly and returns a fixed point. -maximum(abs, sinkhorn_fixed_point(C, u)) +all(iszero, sinkhorn_fixed_point(C, u1; a=a, b=b, ε=ε, T=T)) -# Using the implicit function defined above, we can build an autodiff-compatible implementation of `transportation_plan` which does not require backpropagating through the Sinkhorn iterations: +# Using the implicit function defined above, we can build an autodiff-compatible Sinkhorn which does not require backpropagating through the fixed point iterations: -function transportation_plan(C; a=a, b=b, ε=ε) +function transportation_plan_slow(C; a, b, ε, T) K = exp.(.-C ./ ε) - u = implicit(C) + u = sinkhorn(C; a=a, b=b, ε=ε, T=T) v = b ./ (K' * u) - p_vec = vec(u .* K .* v') - return p_vec + p = u .* K .* v' + return p end; +function transportation_plan_fast(C; a, b, ε, T) + K = exp.(.-C ./ ε) + u = implicit(C; a=a, b=b, ε=ε, T=T) + v = b ./ (K' * u) + p = u .* K .* v' + return p +end; + +# What does the transportation plan look like? + +p1 = transportation_plan_slow(C; a=a, b=b, ε=ε, T=T) +p2 = transportation_plan_fast(C; a=a, b=b, ε=ε, T=T) +p1 == p2 + # Let us compare its Jacobian with the one obtained using finite differences. -J = Zygote.jacobian(transportation_plan, C)[1] -J_ref = FiniteDifferences.jacobian(central_fdm(5, 1), transportation_plan, C)[1] -sum(abs, J - J_ref) / prod(size(J)) +J1 = Zygote.jacobian(C -> transportation_plan_slow(C; a=a, b=b, ε=ε, T=T), C)[1] +J2 = Zygote.jacobian(C -> transportation_plan_fast(C; a=a, b=b, ε=ε, T=T), C)[1] +J_ref = FiniteDifferences.jacobian( + central_fdm(5, 1), C -> transportation_plan_slow(C; a=a, b=b, ε=ε, T=T), C +)[1] + +sum(abs, J2 - J_ref) / prod(size(J_ref)) # The following tests are not included in the docs. #src @testset verbose = true "FiniteDifferences" begin #src - @test sum(abs, J - J_ref) / prod(size(J)) < 1e-5 #src + @test u1 == u2 #src + @test all(iszero, sinkhorn_fixed_point(C, u1; a=a, b=b, ε=ε, T=T)) #src + @test p1 == p2 #src + @test sum(abs, J1 - J_ref) / prod(size(J_ref)) < 1e-5 #src + @test sum(abs, J2 - J_ref) / prod(size(J_ref)) < 1e-5 #src end #src From 4426ebc816f5505551616943ea3070299ce30543 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 29 Aug 2022 17:38:02 +0200 Subject: [PATCH 6/8] Better error message --- src/implicit_function.jl | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/implicit_function.jl b/src/implicit_function.jl index a2d6cb0..0818971 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -30,8 +30,13 @@ function ImplicitFunction(forward::F, conditions::C) where {F,C} return ImplicitFunction(forward, conditions, gmres) end -struct SolverFailureException <: Exception +struct SolverFailureException{S} <: Exception msg::String + stats::S +end + +function Base.show(io::IO, sfe::SolverFailureException) + println(io, "SolverFailureException: $(sfe.msg) \n Solver stats: $(sfe.stats)") end """ @@ -74,7 +79,9 @@ function ChainRulesCore.frule( dx_vec = convert(Vector{R}, vec(unthunk(dx))) b = B * dx_vec dy_vec, stats = linear_solver(A, b) - stats.solved || throw(SolverFailureException("Linear solver failed to converge")) + if !stats.solved + throw(SolverFailureException("Linear solver failed to converge", stats)) + end dy = reshape(dy_vec, size(y)) return y, dy @@ -113,7 +120,9 @@ function ChainRulesCore.rrule( function implicit_pullback(dy) dy_vec = convert(Vector{R}, vec(unthunk(dy))) u, stats = linear_solver(Aᵀ, dy_vec) - stats.solved || throw(SolverFailureException("Linear solver failed to converge")) + if !stats.solved | true + throw(SolverFailureException("Linear solver failed to converge", stats)) + end dx_vec = Bᵀ * u dx = reshape(dx_vec, size(x)) return (NoTangent(), dx) From fddd6724d74050e55b011a1363beeccb41666420 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 29 Aug 2022 17:38:28 +0200 Subject: [PATCH 7/8] Remove if true --- src/implicit_function.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implicit_function.jl b/src/implicit_function.jl index 0818971..e2db0cb 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -120,7 +120,7 @@ function ChainRulesCore.rrule( function implicit_pullback(dy) dy_vec = convert(Vector{R}, vec(unthunk(dy))) u, stats = linear_solver(Aᵀ, dy_vec) - if !stats.solved | true + if !stats.solved throw(SolverFailureException("Linear solver failed to converge", stats)) end dx_vec = Bᵀ * u From 0e40b07f70485e3f59f832f3ca8f58b31bce2b2d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 29 Aug 2022 17:47:25 +0200 Subject: [PATCH 8/8] Fix docs --- src/implicit_function.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/implicit_function.jl b/src/implicit_function.jl index e2db0cb..95177f3 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -52,7 +52,7 @@ Make [`ImplicitFunction{F,C,L}`](@ref) callable by applying `implicit.forward`. Custom forward rule for [`ImplicitFunction{F,C,L}`](@ref). We compute the Jacobian-vector product `Jv` by solving `Au = Bv` and setting `Jv = u`. -Keyword arguments are given to `implicit.forward`, not to `implicit.conditions`. +Keyword arguments are given to both `implicit.forward` and `implicit.conditions`. """ function ChainRulesCore.frule( rc::RuleConfig, (_, dx), implicit::ImplicitFunction, x::AbstractArray{R}; kwargs... @@ -93,7 +93,7 @@ end Custom reverse rule for [`ImplicitFunction{F,C,L}`](@ref). We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = Bᵀu`. -Keyword arguments are given to `implicit.forward`, not to `implicit.conditions`. +Keyword arguments are given to both `implicit.forward` and `implicit.conditions`. """ function ChainRulesCore.rrule( rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}; kwargs...