Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prepare for 1.8 #19

Merged
merged 8 commits into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ jobs:
fail-fast: false
matrix:
version:
# - '1.6'
- '1.7'
- '1.8'
# - 'nightly'
os:
- ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ImplicitDifferentiation"
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"]
version = "0.2.0"
version = "0.3.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
65 changes: 40 additions & 25 deletions src/implicit_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,39 +30,47 @@ function ImplicitFunction(forward::F, conditions::C) where {F,C}
return ImplicitFunction(forward, conditions, gmres)
end

struct SolverFailureException <: Exception
struct SolverFailureException{S} <: Exception
msg::String
stats::S
end

function Base.show(io::IO, sfe::SolverFailureException)
println(io, "SolverFailureException: $(sfe.msg) \n Solver stats: $(sfe.stats)")
end

"""
implicit(x)
implicit(x[; kwargs...])
Make [`ImplicitFunction{F,C,L}`](@ref) callable by applying `implicit.forward`.
"""
(implicit::ImplicitFunction)(x) = implicit.forward(x)
(implicit::ImplicitFunction)(x; kwargs...) = implicit.forward(x; kwargs...)

"""
frule(rc, (_, dx), implicit, x)
frule(rc, (_, dx), implicit, x[; kwargs...])
Custom forward rule for [`ImplicitFunction{F,C,L}`](@ref).
We compute the Jacobian-vector product `Jv` by solving `Au = Bv` and setting `Jv = u`.
Keyword arguments are given to both `implicit.forward` and `implicit.conditions`.
"""
function ChainRulesCore.frule(
rc::RuleConfig, (_, dx), implicit::ImplicitFunction, x::AbstractArray{R}
rc::RuleConfig, (_, dx), implicit::ImplicitFunction, x::AbstractArray{R}; kwargs...
) where {R<:Real}
(; forward, conditions, linear_solver) = implicit
forward = implicit.forward
conditions = implicit.conditions
linear_solver = implicit.linear_solver

y = forward(x)
y = forward(x; kwargs...)

conditions_x(x̃) = conditions(x̃, y)
conditions_y(ỹ) = -conditions(x, ỹ)
conditions_x(x̃; kwargs...) = conditions(x̃, y; kwargs...)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think conditions should have kwargs in general. What's a use case for this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see it used in the OT example, that's clever. You are using kwargs to indicate parameters we don't want to differentiate wrt? Is this consistent with how ChainRules handles kwargs though?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine even if not, but we should document it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ChainRules does not differentiate wrt kwargs, so I think we're good

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I'd like your input on a type inference problem I found with ChainRulesTestUtils:

  • on lines 107-108, kwargs is free, it refers to an argument of the closure
  • on lines 110-111, kwargs is not free, it refers to the keyword arguments passed to the rrule

This is a workaround I found to avoid type inference errors. The natural thing to do was to pass the keyword arguments from the rrule directly in lines 107-108, but for some reason this gave rise to an inference error in the first round of tests (unconstrained optimization). Any clue why? Is that a problem at all?

Copy link
Collaborator

@mohamed82008 mohamed82008 Sep 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reduced further to:

using Test, ChainRulesCore, ChainRulesTestUtils

struct FF{F, Y}
  f::F
  y::Y
end
(f::FF)(x) = f.f(x, f.y)

g(x, y) = x + y;
f = FF(g, rand(3))
rc = ChainRulesTestUtils.TestConfig()
x = rand(3)

@inferred rrule_via_ad(rc, f, x)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps worth opening an issue there

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would probably replace that test with @inferred on an rrule_via_ad call using a Zygote.ZygoteRuleConfig() instead.

Copy link
Member Author

@gdalle gdalle Nov 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would be in favor of merging as-is until the upstream bug is fixed, and then release 0.3.0

conditions_y(ỹ; kwargs...) = -conditions(x, ỹ; kwargs...)

pushforward_A(dỹ) = frule_via_ad(rc, (NoTangent(), dỹ), conditions_y, y)[2]
pushforward_B(dx̃) = frule_via_ad(rc, (NoTangent(), dx̃), conditions_x, x)[2]
pushforward_A(dỹ) = frule_via_ad(rc, (NoTangent(), dỹ), conditions_y, y; kwargs...)[2]
pushforward_B(dx̃) = frule_via_ad(rc, (NoTangent(), dx̃), conditions_x, x; kwargs...)[2]

mul_A!(res, u::AbstractVector) = res .= vec(pushforward_A(reshape(u, size(y))))
mul_B!(res, v::AbstractVector) = res .= vec(pushforward_B(reshape(v, size(x))))
mul_A!(res::Vector, u::Vector) = res .= vec(pushforward_A(reshape(u, size(y))))
mul_B!(res::Vector, v::Vector) = res .= vec(pushforward_B(reshape(v, size(x))))

n, m = length(x), length(y)
A = LinearOperator(R, m, m, false, false, mul_A!)
Expand All @@ -71,34 +79,39 @@ function ChainRulesCore.frule(
dx_vec = convert(Vector{R}, vec(unthunk(dx)))
b = B * dx_vec
dy_vec, stats = linear_solver(A, b)
stats.solved || throw(SolverFailureException("Linear solver failed to converge"))
if !stats.solved
throw(SolverFailureException("Linear solver failed to converge", stats))
end
dy = reshape(dy_vec, size(y))

return y, dy
end

"""
rrule(rc, implicit, x)
rrule(rc, implicit, x[; kwargs...])
Custom reverse rule for [`ImplicitFunction{F,C,L}`](@ref).
We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = Bᵀu`.
Keyword arguments are given to both `implicit.forward` and `implicit.conditions`.
"""
function ChainRulesCore.rrule(
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}; kwargs...
) where {R<:Real}
(; forward, conditions, linear_solver) = implicit
forward = implicit.forward
conditions = implicit.conditions
linear_solver = implicit.linear_solver

y = forward(x)
y = forward(x; kwargs...)

conditions_x(x̃) = conditions(x̃, y)
conditions_y(ỹ) = -conditions(x, ỹ)
conditions_x(x̃; kwargs...) = conditions(x̃, y; kwargs...)
conditions_y(ỹ; kwargs...) = -conditions(x, ỹ; kwargs...)

pullback_Aᵀ = last rrule_via_ad(rc, conditions_y, y)[2]
pullback_Bᵀ = last rrule_via_ad(rc, conditions_x, x)[2]
pullback_Aᵀ = rrule_via_ad(rc, conditions_y, y; kwargs...)[2]
pullback_Bᵀ = rrule_via_ad(rc, conditions_x, x; kwargs...)[2]

mul_Aᵀ!(res, u::AbstractVector) = res .= vec(pullback_Aᵀ(reshape(u, size(y))))
mul_Bᵀ!(res, v::AbstractVector) = res .= vec(pullback_Bᵀ(reshape(v, size(y))))
mul_Aᵀ!(res::Vector, u::Vector) = res .= vec(pullback_Aᵀ(reshape(u, size(y)))[2])
mul_Bᵀ!(res::Vector, v::Vector) = res .= vec(pullback_Bᵀ(reshape(v, size(y)))[2])

n, m = length(x), length(y)
Aᵀ = LinearOperator(R, m, m, false, false, mul_Aᵀ!)
Expand All @@ -107,7 +120,9 @@ function ChainRulesCore.rrule(
function implicit_pullback(dy)
dy_vec = convert(Vector{R}, vec(unthunk(dy)))
u, stats = linear_solver(Aᵀ, dy_vec)
stats.solved || throw(SolverFailureException("Linear solver failed to converge"))
if !stats.solved
throw(SolverFailureException("Linear solver failed to converge", stats))
end
dx_vec = Bᵀ * u
dx = reshape(dx_vec, size(x))
return (NoTangent(), dx)
Expand Down
6 changes: 5 additions & 1 deletion test/1_unconstrained_optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ Zygote.jacobian(implicit, x)[1]

# Note that implicit differentiation was necessary here, since our solver alone doesn't support autodiff with `Zygote.jl`.

try; Zygote.jacobian(dumb_identity, x)[1]; catch e; @error e; end
try
Zygote.jacobian(dumb_identity, x)[1]
catch e
e
end

# The following tests are not included in the docs. #src

Expand Down
6 changes: 5 additions & 1 deletion test/2_sparse_linear_regression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ round.(implicit(data); digits=4)

# Note that implicit differentiation is necessary here because the convex solver breaks autodiff.

try; Zygote.jacobian(lasso, data); catch e; @error e; end
try
Zygote.jacobian(lasso, data)
catch e
e
end

# Meanwhile, our implicit wrapper makes autodiff work seamlessly.

Expand Down
76 changes: 57 additions & 19 deletions test/3_optimal_transport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,35 @@ Y = rand(d, m)

a = fill(1 / n, n)
b = fill(1 / m, m)
C = pairwise(SqEuclidean(), X, Y, dims=2)
C = pairwise(SqEuclidean(), X, Y; dims=2)

ε = 1.;
ε = 1.0;
T = 100;

# ## Forward solver

# For technical reasons related to optimality checking, our Sinkhorn solver returns ``\hat{u}`` instead of ``\hat{p}_\varepsilon``.

function sinkhorn(C; a=a, b=b, ε)
function sinkhorn(C; a, b, ε, T)
K = exp.(.-C ./ ε)
u = copy(a)
v = copy(b)
for t in 1:100
u .= a ./ (K * v)
v .= b ./ (K' * u)
for t in 1:T
u = a ./ (K * v)
v = b ./ (K' * u)
end
return u
end

function sinkhorn_efficient(C; a, b, ε, T)
K = exp.(.-C ./ ε)
u = copy(a)
v = copy(b)
for t in 1:T
mul!(u, K, v)
u .= a ./ u
mul!(v, K', u)
v .= b ./ v
end
return u
end
Expand All @@ -92,42 +106,66 @@ end

# We simply used the fixed point equation $(\text{S})$.

function sinkhorn_fixed_point(C, u; a=a, b=b, ε)
function sinkhorn_fixed_point(C, u; a, b, ε, T=nothing)
K = exp.(.-C ./ ε)
v = b ./ (K' * u)
return u .- a ./ (K * v)
end

# We have all we need to build a differentiable Sinkhorn that doesn't require unrolling the fixed point iterations.

implicit = ImplicitFunction(sinkhorn, sinkhorn_fixed_point);
implicit = ImplicitFunction(sinkhorn_efficient, sinkhorn_fixed_point);

# ## Testing

u = sinkhorn(C)
u1 = sinkhorn(C; a=a, b=b, ε=ε, T=T)
u2 = implicit(C; a=a, b=b, ε=ε, T=T)
u1 == u2

# First, let us check that the forward pass works correctly and returns a fixed point.

maximum(abs, sinkhorn_fixed_point(C, u))
all(iszero, sinkhorn_fixed_point(C, u1; a=a, b=b, ε=ε, T=T))

# Using the implicit function defined above, we can build an autodiff-compatible implementation of `transportation_plan` which does not require backpropagating through the Sinkhorn iterations:
# Using the implicit function defined above, we can build an autodiff-compatible Sinkhorn which does not require backpropagating through the fixed point iterations:

function transportation_plan(C; a=a, b=b, ε)
function transportation_plan_slow(C; a, b, ε, T)
K = exp.(.-C ./ ε)
u = implicit(C)
u = sinkhorn(C; a=a, b=b, ε=ε, T=T)
v = b ./ (K' * u)
p_vec = vec(u .* K .* v')
return p_vec
p = u .* K .* v'
return p
end;

function transportation_plan_fast(C; a, b, ε, T)
K = exp.(.-C ./ ε)
u = implicit(C; a=a, b=b, ε=ε, T=T)
v = b ./ (K' * u)
p = u .* K .* v'
return p
end;

# What does the transportation plan look like?

p1 = transportation_plan_slow(C; a=a, b=b, ε=ε, T=T)
p2 = transportation_plan_fast(C; a=a, b=b, ε=ε, T=T)
p1 == p2

# Let us compare its Jacobian with the one obtained using finite differences.

J = Zygote.jacobian(transportation_plan, C)[1]
J_ref = FiniteDifferences.jacobian(central_fdm(5, 1), transportation_plan, C)[1]
sum(abs, J - J_ref) / prod(size(J))
J1 = Zygote.jacobian(C -> transportation_plan_slow(C; a=a, b=b, ε=ε, T=T), C)[1]
J2 = Zygote.jacobian(C -> transportation_plan_fast(C; a=a, b=b, ε=ε, T=T), C)[1]
J_ref = FiniteDifferences.jacobian(
central_fdm(5, 1), C -> transportation_plan_slow(C; a=a, b=b, ε=ε, T=T), C
)[1]

sum(abs, J2 - J_ref) / prod(size(J_ref))

# The following tests are not included in the docs. #src

@testset verbose = true "FiniteDifferences" begin #src
@test sum(abs, J - J_ref) / prod(size(J)) < 1e-5 #src
@test u1 == u2 #src
@test all(iszero, sinkhorn_fixed_point(C, u1; a=a, b=b, ε=ε, T=T)) #src
@test p1 == p2 #src
@test sum(abs, J1 - J_ref) / prod(size(J_ref)) < 1e-5 #src
@test sum(abs, J2 - J_ref) / prod(size(J_ref)) < 1e-5 #src
end #src
31 changes: 1 addition & 30 deletions test/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

julia_version = "1.7.3"
manifest_format = "2.0"
project_hash = "252ec556cc576958bb4293ddf4b3affa1825c9d0"

[[deps.AMD]]
deps = ["Libdl", "LinearAlgebra", "SparseArrays", "Test"]
Expand Down Expand Up @@ -89,12 +90,6 @@ git-tree-sha1 = "1e315e3f4b0b7ce40feded39c73049692126cf53"
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
version = "0.1.3"

[[deps.CodeTracking]]
deps = ["InteractiveUtils", "UUIDs"]
git-tree-sha1 = "6d4fa04343a7fc9f9cb9cff9558929f3d2752717"
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
version = "1.0.9"

[[deps.CodecBzip2]]
deps = ["Bzip2_jll", "Libdl", "TranscodingStreams"]
git-tree-sha1 = "2e62a725210ce3c3c2e1a3080190e7ca491f18d7"
Expand Down Expand Up @@ -249,12 +244,6 @@ git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.1.1"

[[deps.JET]]
deps = ["InteractiveUtils", "JuliaInterpreter", "LoweredCodeUtils", "MacroTools", "Pkg", "Revise", "Test"]
git-tree-sha1 = "8e78b0c297cfa6cefd579f87232c89bd6ed7a081"
uuid = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
version = "0.5.16"

[[deps.JLLWrappers]]
deps = ["Preferences"]
git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1"
Expand All @@ -267,12 +256,6 @@ git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.3"

[[deps.JuliaInterpreter]]
deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"]
git-tree-sha1 = "52617c41d2761cc05ed81fe779804d3b7f14fff7"
uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
version = "0.9.13"

[[deps.Krylov]]
deps = ["LinearAlgebra", "Printf", "SparseArrays"]
git-tree-sha1 = "7f0a89bd74c30aa7ff96c4bf1bc884c39663a621"
Expand Down Expand Up @@ -329,12 +312,6 @@ version = "0.3.15"
[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[deps.LoweredCodeUtils]]
deps = ["JuliaInterpreter"]
git-tree-sha1 = "dedbebe234e06e1ddad435f5c6f4b85cd8ce55f7"
uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b"
version = "2.2.2"

[[deps.MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf"
Expand Down Expand Up @@ -480,12 +457,6 @@ git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.3.0"

[[deps.Revise]]
deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "Pkg", "REPL", "Requires", "UUIDs", "Unicode"]
git-tree-sha1 = "4d4239e93531ac3e7ca7e339f15978d0b5149d03"
uuid = "295af30f-e4ad-537b-8983-00126c2a3abe"
version = "3.3.3"

[[deps.Richardson]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "e03ca566bec93f8a3aeb059c8ef102f268a38949"
Expand Down
2 changes: 0 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Convex = "f65535da-76fb-5f13-bab9-19810c17039a"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
Expand All @@ -27,7 +26,6 @@ ComponentArrays = "0.12.2"
Convex = "0.15.1"
Distances = "0.10.7"
FiniteDifferences = "0.12.24"
JET = "0.5.16"
Krylov = "0.8.2"
LinearOperators = "2.3.2"
MathOptInterface = "1.5"
Expand Down
Loading