Skip to content

Commit

Permalink
Support second order for Enzyme, take 2 (#285)
Browse files Browse the repository at this point in the history
* Add nested mechanism

* Mode

* Typo

* Tests passing

* Remove additional tests FD+Enzyme

* Add printing

* Remove printing

* More code coverage

* Fix Heisenbug

* Fix Heisenbug

* Debump
  • Loading branch information
gdalle authored May 30, 2024
1 parent c6aaabe commit 39dda67
Show file tree
Hide file tree
Showing 23 changed files with 332 additions and 181 deletions.
4 changes: 3 additions & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.5.2"
version = "0.5.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -88,6 +88,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand All @@ -106,5 +107,6 @@ test = [
"SparseArrays",
"SparseConnectivityTracer",
"SparseMatrixColorings",
"StableRNGs",
"Test",
]
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,40 @@ using Enzyme:
DuplicatedNoNeed,
Forward,
ForwardMode,
Mode,
Reverse,
ReverseWithPrimal,
ReverseSplitWithPrimal,
ReverseMode,
autodiff,
autodiff_deferred,
autodiff_deferred_thunk,
autodiff_thunk,
chunkedonehot,
gradient,
gradient!,
jacobian,
make_zero

const AutoForwardEnzyme = AutoEnzyme{<:ForwardMode}
const AutoForwardOrNothingEnzyme = Union{AutoEnzyme{<:ForwardMode},AutoEnzyme{Nothing}}
const AutoReverseEnzyme = AutoEnzyme{<:ReverseMode}
const AutoReverseOrNothingEnzyme = Union{AutoEnzyme{<:ReverseMode},AutoEnzyme{Nothing}}
struct AutoDeferredEnzyme{M} <: ADTypes.AbstractADType
mode::M
end

ADTypes.mode(backend::AutoDeferredEnzyme) = ADTypes.mode(AutoEnzyme(backend.mode))

DI.backend_package_name(::AutoDeferredEnzyme) = "DeferredEnzyme"

DI.nested(backend::AutoEnzyme) = AutoDeferredEnzyme(backend.mode)

forward_mode(backend::AutoEnzyme{<:ForwardMode}) = backend.mode
forward_mode(::AutoEnzyme{Nothing}) = Forward
const AnyAutoEnzyme{M} = Union{AutoEnzyme{M},AutoDeferredEnzyme{M}}

reverse_mode(backend::AutoEnzyme{<:ReverseMode}) = backend.mode
reverse_mode(::AutoEnzyme{Nothing}) = Reverse
# forward mode if possible
forward_mode(backend::AnyAutoEnzyme{<:Mode}) = backend.mode
forward_mode(::AnyAutoEnzyme{Nothing}) = Forward

# reverse mode if possible
reverse_mode(backend::AnyAutoEnzyme{<:Mode}) = backend.mode
reverse_mode(::AnyAutoEnzyme{Nothing}) = Reverse

DI.check_available(::AutoEnzyme) = true

Expand All @@ -54,12 +66,6 @@ function DI.basis(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T
return b
end

function zero_sametype!(x_target, x)
x_sametype = convert(typeof(x), x_target)
x_sametype .= zero(eltype(x_sametype))
return x_sametype
end

include("forward_onearg.jl")
include("forward_twoarg.jl")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,39 +1,55 @@
## Pushforward

DI.prepare_pushforward(f, ::AutoForwardOrNothingEnzyme, x, dx) = NoPushforwardExtras()
function DI.prepare_pushforward(f, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx)
return NoPushforwardExtras()
end

function DI.value_and_pushforward(
f, backend::AutoForwardOrNothingEnzyme, x, dx, ::NoPushforwardExtras
f, backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, ::NoPushforwardExtras
)
dx_sametype = convert(typeof(x), dx)
y, new_dy = autodiff(
forward_mode(backend), Const(f), Duplicated, Duplicated(x, dx_sametype)
)
x_and_dx = Duplicated(x, dx_sametype)
y, new_dy = if backend isa AutoDeferredEnzyme
autodiff_deferred(forward_mode(backend), f, Duplicated, x_and_dx)
else
autodiff(forward_mode(backend), Const(f), Duplicated, x_and_dx)
end
return y, new_dy
end

function DI.pushforward(
f, backend::AutoForwardOrNothingEnzyme, x, dx, ::NoPushforwardExtras
f, backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, ::NoPushforwardExtras
)
dx_sametype = convert(typeof(x), dx)
new_dy = only(
autodiff(
forward_mode(backend), Const(f), DuplicatedNoNeed, Duplicated(x, dx_sametype)
),
)
x_and_dx = Duplicated(x, dx_sametype)
new_dy = if backend isa AutoDeferredEnzyme
only(autodiff_deferred(forward_mode(backend), f, DuplicatedNoNeed, x_and_dx))
else
only(autodiff(forward_mode(backend), Const(f), DuplicatedNoNeed, x_and_dx))
end
return new_dy
end

function DI.value_and_pushforward!(
f, dy, backend::AutoForwardOrNothingEnzyme, x, dx, extras::NoPushforwardExtras
f,
dy,
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
dx,
extras::NoPushforwardExtras,
)
# dy cannot be passed anyway
y, new_dy = DI.value_and_pushforward(f, backend, x, dx, extras)
return y, copyto!(dy, new_dy)
end

function DI.pushforward!(
f, dy, backend::AutoForwardOrNothingEnzyme, x, dx, extras::NoPushforwardExtras
f,
dy,
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
dx,
extras::NoPushforwardExtras,
)
# dy cannot be passed anyway
return copyto!(dy, DI.pushforward(f, backend, x, dx, extras))
Expand All @@ -45,34 +61,34 @@ struct EnzymeForwardGradientExtras{C,O} <: GradientExtras
shadow::O
end

function DI.prepare_gradient(f, ::AutoForwardEnzyme, x)
function DI.prepare_gradient(f, ::AutoEnzyme{<:ForwardMode}, x)
C = pick_chunksize(length(x))
shadow = chunkedonehot(x, Val(C))
return EnzymeForwardGradientExtras{C,typeof(shadow)}(shadow)
end

function DI.gradient(
f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C}
f, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C}
) where {C}
grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
return reshape(collect(grad_tup), size(x))
end

function DI.value_and_gradient(
f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras
f, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras
)
return f(x), DI.gradient(f, backend, x, extras)
end

function DI.gradient!(
f, grad, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C}
f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C}
) where {C}
grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
return copyto!(grad, grad_tup)
end

function DI.value_and_gradient!(
f, grad, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C}
f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C}
) where {C}
grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
return f(x), copyto!(grad, grad_tup)
Expand All @@ -84,14 +100,17 @@ struct EnzymeForwardOneArgJacobianExtras{C,O} <: JacobianExtras
shadow::O
end

function DI.prepare_jacobian(f, ::AutoForwardOrNothingEnzyme, x)
function DI.prepare_jacobian(f, ::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x)
C = pick_chunksize(length(x))
shadow = chunkedonehot(x, Val(C))
return EnzymeForwardOneArgJacobianExtras{C,typeof(shadow)}(shadow)
end

function DI.jacobian(
f, backend::AutoForwardOrNothingEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras{C}
f,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
extras::EnzymeForwardOneArgJacobianExtras{C},
) where {C}
jac_wrongshape = jacobian(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
nx = length(x)
Expand All @@ -100,15 +119,18 @@ function DI.jacobian(
end

function DI.value_and_jacobian(
f, backend::AutoForwardOrNothingEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras
f,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
extras::EnzymeForwardOneArgJacobianExtras,
)
return f(x), DI.jacobian(f, backend, x, extras)
end

function DI.jacobian!(
f,
jac,
backend::AutoForwardOrNothingEnzyme,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
extras::EnzymeForwardOneArgJacobianExtras,
)
Expand All @@ -118,7 +140,7 @@ end
function DI.value_and_jacobian!(
f,
jac,
backend::AutoForwardOrNothingEnzyme,
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
extras::EnzymeForwardOneArgJacobianExtras,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
## Pushforward

DI.prepare_pushforward(f!, y, ::AutoForwardOrNothingEnzyme, x, dx) = NoPushforwardExtras()
function DI.prepare_pushforward(f!, y, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx)
return NoPushforwardExtras()
end

function DI.value_and_pushforward(
f!, y, backend::AutoForwardOrNothingEnzyme, x, dx, ::NoPushforwardExtras
f!,
y,
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
dx,
::NoPushforwardExtras,
)
dx_sametype = convert(typeof(x), dx)
dy_sametype = zero(y)
autodiff(
forward_mode(backend),
Const(f!),
Const,
Duplicated(y, dy_sametype),
Duplicated(x, dx_sametype),
)
dy_sametype = make_zero(y)
y_and_dy = Duplicated(y, dy_sametype)
x_and_dx = Duplicated(x, dx_sametype)
if backend isa AutoDeferredEnzyme
autodiff_deferred(forward_mode(backend), f!, Const, y_and_dy, x_and_dx)
else
autodiff(forward_mode(backend), Const(f!), Const, y_and_dy, x_and_dx)
end
return y, dy_sametype
end
Loading

0 comments on commit 39dda67

Please sign in to comment.