Skip to content

Commit

Permalink
Merge pull request #19 from gdalle/test_18
Browse files Browse the repository at this point in the history
Prepare for 1.8
  • Loading branch information
gdalle authored Nov 21, 2022
2 parents d1feed2 + 0e40b07 commit 10478ea
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 89 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ jobs:
fail-fast: false
matrix:
version:
# - '1.6'
- '1.7'
- '1.8'
# - 'nightly'
os:
- ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ImplicitDifferentiation"
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"]
version = "0.2.0"
version = "0.3.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
65 changes: 40 additions & 25 deletions src/implicit_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,39 +30,47 @@ function ImplicitFunction(forward::F, conditions::C) where {F,C}
return ImplicitFunction(forward, conditions, gmres)
end

struct SolverFailureException <: Exception
struct SolverFailureException{S} <: Exception
msg::String
stats::S
end

function Base.show(io::IO, sfe::SolverFailureException)
println(io, "SolverFailureException: $(sfe.msg) \n Solver stats: $(sfe.stats)")
end

"""
implicit(x)
implicit(x[; kwargs...])
Make [`ImplicitFunction{F,C,L}`](@ref) callable by applying `implicit.forward`.
"""
(implicit::ImplicitFunction)(x) = implicit.forward(x)
(implicit::ImplicitFunction)(x; kwargs...) = implicit.forward(x; kwargs...)

"""
frule(rc, (_, dx), implicit, x)
frule(rc, (_, dx), implicit, x[; kwargs...])
Custom forward rule for [`ImplicitFunction{F,C,L}`](@ref).
We compute the Jacobian-vector product `Jv` by solving `Au = Bv` and setting `Jv = u`.
Keyword arguments are given to both `implicit.forward` and `implicit.conditions`.
"""
function ChainRulesCore.frule(
rc::RuleConfig, (_, dx), implicit::ImplicitFunction, x::AbstractArray{R}
rc::RuleConfig, (_, dx), implicit::ImplicitFunction, x::AbstractArray{R}; kwargs...
) where {R<:Real}
(; forward, conditions, linear_solver) = implicit
forward = implicit.forward
conditions = implicit.conditions
linear_solver = implicit.linear_solver

y = forward(x)
y = forward(x; kwargs...)

conditions_x(x̃) = conditions(x̃, y)
conditions_y(ỹ) = -conditions(x, ỹ)
conditions_x(x̃; kwargs...) = conditions(x̃, y; kwargs...)
conditions_y(ỹ; kwargs...) = -conditions(x, ỹ; kwargs...)

pushforward_A(dỹ) = frule_via_ad(rc, (NoTangent(), dỹ), conditions_y, y)[2]
pushforward_B(dx̃) = frule_via_ad(rc, (NoTangent(), dx̃), conditions_x, x)[2]
pushforward_A(dỹ) = frule_via_ad(rc, (NoTangent(), dỹ), conditions_y, y; kwargs...)[2]
pushforward_B(dx̃) = frule_via_ad(rc, (NoTangent(), dx̃), conditions_x, x; kwargs...)[2]

mul_A!(res, u::AbstractVector) = res .= vec(pushforward_A(reshape(u, size(y))))
mul_B!(res, v::AbstractVector) = res .= vec(pushforward_B(reshape(v, size(x))))
mul_A!(res::Vector, u::Vector) = res .= vec(pushforward_A(reshape(u, size(y))))
mul_B!(res::Vector, v::Vector) = res .= vec(pushforward_B(reshape(v, size(x))))

n, m = length(x), length(y)
A = LinearOperator(R, m, m, false, false, mul_A!)
Expand All @@ -71,34 +79,39 @@ function ChainRulesCore.frule(
dx_vec = convert(Vector{R}, vec(unthunk(dx)))
b = B * dx_vec
dy_vec, stats = linear_solver(A, b)
stats.solved || throw(SolverFailureException("Linear solver failed to converge"))
if !stats.solved
throw(SolverFailureException("Linear solver failed to converge", stats))
end
dy = reshape(dy_vec, size(y))

return y, dy
end

"""
rrule(rc, implicit, x)
rrule(rc, implicit, x[; kwargs...])
Custom reverse rule for [`ImplicitFunction{F,C,L}`](@ref).
We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = Bᵀu`.
Keyword arguments are given to both `implicit.forward` and `implicit.conditions`.
"""
function ChainRulesCore.rrule(
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}; kwargs...
) where {R<:Real}
(; forward, conditions, linear_solver) = implicit
forward = implicit.forward
conditions = implicit.conditions
linear_solver = implicit.linear_solver

y = forward(x)
y = forward(x; kwargs...)

conditions_x(x̃) = conditions(x̃, y)
conditions_y(ỹ) = -conditions(x, ỹ)
conditions_x(x̃; kwargs...) = conditions(x̃, y; kwargs...)
conditions_y(ỹ; kwargs...) = -conditions(x, ỹ; kwargs...)

pullback_Aᵀ = last rrule_via_ad(rc, conditions_y, y)[2]
pullback_Bᵀ = last rrule_via_ad(rc, conditions_x, x)[2]
pullback_Aᵀ = rrule_via_ad(rc, conditions_y, y; kwargs...)[2]
pullback_Bᵀ = rrule_via_ad(rc, conditions_x, x; kwargs...)[2]

mul_Aᵀ!(res, u::AbstractVector) = res .= vec(pullback_Aᵀ(reshape(u, size(y))))
mul_Bᵀ!(res, v::AbstractVector) = res .= vec(pullback_Bᵀ(reshape(v, size(y))))
mul_Aᵀ!(res::Vector, u::Vector) = res .= vec(pullback_Aᵀ(reshape(u, size(y)))[2])
mul_Bᵀ!(res::Vector, v::Vector) = res .= vec(pullback_Bᵀ(reshape(v, size(y)))[2])

n, m = length(x), length(y)
Aᵀ = LinearOperator(R, m, m, false, false, mul_Aᵀ!)
Expand All @@ -107,7 +120,9 @@ function ChainRulesCore.rrule(
function implicit_pullback(dy)
dy_vec = convert(Vector{R}, vec(unthunk(dy)))
u, stats = linear_solver(Aᵀ, dy_vec)
stats.solved || throw(SolverFailureException("Linear solver failed to converge"))
if !stats.solved
throw(SolverFailureException("Linear solver failed to converge", stats))
end
dx_vec = Bᵀ * u
dx = reshape(dx_vec, size(x))
return (NoTangent(), dx)
Expand Down
6 changes: 5 additions & 1 deletion test/1_unconstrained_optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ Zygote.jacobian(implicit, x)[1]

# Note that implicit differentiation was necessary here, since our solver alone doesn't support autodiff with `Zygote.jl`.

try; Zygote.jacobian(dumb_identity, x)[1]; catch e; @error e; end
try
Zygote.jacobian(dumb_identity, x)[1]
catch e
e
end

# The following tests are not included in the docs. #src

Expand Down
6 changes: 5 additions & 1 deletion test/2_sparse_linear_regression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ round.(implicit(data); digits=4)

# Note that implicit differentiation is necessary here because the convex solver breaks autodiff.

try; Zygote.jacobian(lasso, data); catch e; @error e; end
try
Zygote.jacobian(lasso, data)
catch e
e
end

# Meanwhile, our implicit wrapper makes autodiff work seamlessly.

Expand Down
76 changes: 57 additions & 19 deletions test/3_optimal_transport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,35 @@ Y = rand(d, m)

a = fill(1 / n, n)
b = fill(1 / m, m)
C = pairwise(SqEuclidean(), X, Y, dims=2)
C = pairwise(SqEuclidean(), X, Y; dims=2)

ε = 1.;
ε = 1.0;
T = 100;

# ## Forward solver

# For technical reasons related to optimality checking, our Sinkhorn solver returns ``\hat{u}`` instead of ``\hat{p}_\varepsilon``.

function sinkhorn(C; a=a, b=b, ε=ε)
function sinkhorn(C; a, b, ε, T)
K = exp.(.-C ./ ε)
u = copy(a)
v = copy(b)
for t in 1:100
u .= a ./ (K * v)
v .= b ./ (K' * u)
for t in 1:T
u = a ./ (K * v)
v = b ./ (K' * u)
end
return u
end

function sinkhorn_efficient(C; a, b, ε, T)
K = exp.(.-C ./ ε)
u = copy(a)
v = copy(b)
for t in 1:T
mul!(u, K, v)
u .= a ./ u
mul!(v, K', u)
v .= b ./ v
end
return u
end
Expand All @@ -92,42 +106,66 @@ end

# We simply used the fixed point equation $(\text{S})$.

function sinkhorn_fixed_point(C, u; a=a, b=b, ε=ε)
function sinkhorn_fixed_point(C, u; a, b, ε, T=nothing)
K = exp.(.-C ./ ε)
v = b ./ (K' * u)
return u .- a ./ (K * v)
end

# We have all we need to build a differentiable Sinkhorn that doesn't require unrolling the fixed point iterations.

implicit = ImplicitFunction(sinkhorn, sinkhorn_fixed_point);
implicit = ImplicitFunction(sinkhorn_efficient, sinkhorn_fixed_point);

# ## Testing

u = sinkhorn(C)
u1 = sinkhorn(C; a=a, b=b, ε=ε, T=T)
u2 = implicit(C; a=a, b=b, ε=ε, T=T)
u1 == u2

# First, let us check that the forward pass works correctly and returns a fixed point.

maximum(abs, sinkhorn_fixed_point(C, u))
all(iszero, sinkhorn_fixed_point(C, u1; a=a, b=b, ε=ε, T=T))

# Using the implicit function defined above, we can build an autodiff-compatible implementation of `transportation_plan` which does not require backpropagating through the Sinkhorn iterations:
# Using the implicit function defined above, we can build an autodiff-compatible Sinkhorn which does not require backpropagating through the fixed point iterations:

function transportation_plan(C; a=a, b=b, ε=ε)
function transportation_plan_slow(C; a, b, ε, T)
K = exp.(.-C ./ ε)
u = implicit(C)
u = sinkhorn(C; a=a, b=b, ε=ε, T=T)
v = b ./ (K' * u)
p_vec = vec(u .* K .* v')
return p_vec
p = u .* K .* v'
return p
end;

function transportation_plan_fast(C; a, b, ε, T)
K = exp.(.-C ./ ε)
u = implicit(C; a=a, b=b, ε=ε, T=T)
v = b ./ (K' * u)
p = u .* K .* v'
return p
end;

# What does the transportation plan look like?

p1 = transportation_plan_slow(C; a=a, b=b, ε=ε, T=T)
p2 = transportation_plan_fast(C; a=a, b=b, ε=ε, T=T)
p1 == p2

# Let us compare its Jacobian with the one obtained using finite differences.

J = Zygote.jacobian(transportation_plan, C)[1]
J_ref = FiniteDifferences.jacobian(central_fdm(5, 1), transportation_plan, C)[1]
sum(abs, J - J_ref) / prod(size(J))
J1 = Zygote.jacobian(C -> transportation_plan_slow(C; a=a, b=b, ε=ε, T=T), C)[1]
J2 = Zygote.jacobian(C -> transportation_plan_fast(C; a=a, b=b, ε=ε, T=T), C)[1]
J_ref = FiniteDifferences.jacobian(
central_fdm(5, 1), C -> transportation_plan_slow(C; a=a, b=b, ε=ε, T=T), C
)[1]

sum(abs, J2 - J_ref) / prod(size(J_ref))

# The following tests are not included in the docs. #src

@testset verbose = true "FiniteDifferences" begin #src
@test sum(abs, J - J_ref) / prod(size(J)) < 1e-5 #src
@test u1 == u2 #src
@test all(iszero, sinkhorn_fixed_point(C, u1; a=a, b=b, ε=ε, T=T)) #src
@test p1 == p2 #src
@test sum(abs, J1 - J_ref) / prod(size(J_ref)) < 1e-5 #src
@test sum(abs, J2 - J_ref) / prod(size(J_ref)) < 1e-5 #src
end #src
31 changes: 1 addition & 30 deletions test/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

julia_version = "1.7.3"
manifest_format = "2.0"
project_hash = "252ec556cc576958bb4293ddf4b3affa1825c9d0"

[[deps.AMD]]
deps = ["Libdl", "LinearAlgebra", "SparseArrays", "Test"]
Expand Down Expand Up @@ -89,12 +90,6 @@ git-tree-sha1 = "1e315e3f4b0b7ce40feded39c73049692126cf53"
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
version = "0.1.3"

[[deps.CodeTracking]]
deps = ["InteractiveUtils", "UUIDs"]
git-tree-sha1 = "6d4fa04343a7fc9f9cb9cff9558929f3d2752717"
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
version = "1.0.9"

[[deps.CodecBzip2]]
deps = ["Bzip2_jll", "Libdl", "TranscodingStreams"]
git-tree-sha1 = "2e62a725210ce3c3c2e1a3080190e7ca491f18d7"
Expand Down Expand Up @@ -249,12 +244,6 @@ git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.1.1"

[[deps.JET]]
deps = ["InteractiveUtils", "JuliaInterpreter", "LoweredCodeUtils", "MacroTools", "Pkg", "Revise", "Test"]
git-tree-sha1 = "8e78b0c297cfa6cefd579f87232c89bd6ed7a081"
uuid = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
version = "0.5.16"

[[deps.JLLWrappers]]
deps = ["Preferences"]
git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1"
Expand All @@ -267,12 +256,6 @@ git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.3"

[[deps.JuliaInterpreter]]
deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"]
git-tree-sha1 = "52617c41d2761cc05ed81fe779804d3b7f14fff7"
uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
version = "0.9.13"

[[deps.Krylov]]
deps = ["LinearAlgebra", "Printf", "SparseArrays"]
git-tree-sha1 = "7f0a89bd74c30aa7ff96c4bf1bc884c39663a621"
Expand Down Expand Up @@ -329,12 +312,6 @@ version = "0.3.15"
[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[deps.LoweredCodeUtils]]
deps = ["JuliaInterpreter"]
git-tree-sha1 = "dedbebe234e06e1ddad435f5c6f4b85cd8ce55f7"
uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b"
version = "2.2.2"

[[deps.MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf"
Expand Down Expand Up @@ -480,12 +457,6 @@ git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.3.0"

[[deps.Revise]]
deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "Pkg", "REPL", "Requires", "UUIDs", "Unicode"]
git-tree-sha1 = "4d4239e93531ac3e7ca7e339f15978d0b5149d03"
uuid = "295af30f-e4ad-537b-8983-00126c2a3abe"
version = "3.3.3"

[[deps.Richardson]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "e03ca566bec93f8a3aeb059c8ef102f268a38949"
Expand Down
2 changes: 0 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Convex = "f65535da-76fb-5f13-bab9-19810c17039a"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
Expand All @@ -27,7 +26,6 @@ ComponentArrays = "0.12.2"
Convex = "0.15.1"
Distances = "0.10.7"
FiniteDifferences = "0.12.24"
JET = "0.5.16"
Krylov = "0.8.2"
LinearOperators = "2.3.2"
MathOptInterface = "1.5"
Expand Down
Loading

0 comments on commit 10478ea

Please sign in to comment.