Skip to content

Commit

Permalink
Support static arrays with reverse Enzyme (#585)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Oct 16, 2024
1 parent 94f9bc5 commit d4b17c1
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 78 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.13"
version = "0.6.14"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,40 +61,6 @@ end

### Out-of-place

function DI.value_and_pullback(
f::F,
::NoPullbackPrep,
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
x::Number,
ty::NTuple{1},
contexts::Vararg{Context,C},
) where {F,C}
f_and_df = force_annotation(get_f_and_df(f, backend))
mode = reverse_split_withprimal(backend)
RA = eltype(ty) <: Number ? Active : Duplicated
dinputs, result = seeded_autodiff_thunk(
mode, only(ty), f_and_df, RA, Active(x), map(translate, contexts)...
)
return result, (first(dinputs),)
end

function DI.value_and_pullback(
f::F,
::NoPullbackPrep,
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
x::Number,
ty::NTuple{B},
contexts::Vararg{Context,C},
) where {F,B,C}
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
mode = reverse_split_withprimal(backend)
RA = eltype(ty) <: Number ? Active : BatchDuplicated
dinputs, result = batch_seeded_autodiff_thunk(
mode, ty, f_and_df, RA, Active(x), map(translate, contexts)...
)
return result, values(first(dinputs))
end

function DI.value_and_pullback(
f::F,
::NoPullbackPrep,
Expand All @@ -105,12 +71,18 @@ function DI.value_and_pullback(
) where {F,C}
f_and_df = force_annotation(get_f_and_df(f, backend))
mode = reverse_split_withprimal(backend)
RA = eltype(ty) <: Number ? Active : Duplicated
IA = guess_activity(typeof(x), mode)
RA = guess_activity(eltype(ty), mode)
dx = make_zero(x)
_, result = seeded_autodiff_thunk(
mode, only(ty), f_and_df, RA, Duplicated(x, dx), map(translate, contexts)...
dinputs, result = seeded_autodiff_thunk(
mode, only(ty), f_and_df, RA, annotate(IA, x, dx), map(translate, contexts)...
)
return result, (dx,)
new_dx = first(dinputs)
if isnothing(new_dx)
return result, (dx,)
else
return result, (new_dx,)
end
end

function DI.value_and_pullback(
Expand All @@ -123,12 +95,18 @@ function DI.value_and_pullback(
) where {F,B,C}
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
mode = reverse_split_withprimal(backend)
RA = eltype(ty) <: Number ? Active : BatchDuplicated
IA = batchify_activity(guess_activity(typeof(x), mode), Val(B))
RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B))
tx = ntuple(_ -> make_zero(x), Val(B))
_, result = batch_seeded_autodiff_thunk(
mode, ty, f_and_df, RA, BatchDuplicated(x, tx), map(translate, contexts)...
dinputs, result = batch_seeded_autodiff_thunk(
mode, ty, f_and_df, RA, annotate(IA, x, tx), map(translate, contexts)...
)
return result, tx
new_tx = values(first(dinputs))
if isnothing(new_tx)
return result, tx
else
return result, new_tx
end
end

function DI.pullback(
Expand All @@ -155,7 +133,7 @@ function DI.value_and_pullback!(
) where {F,C}
f_and_df = force_annotation(get_f_and_df(f, backend))
mode = reverse_split_withprimal(backend)
RA = eltype(ty) <: Number ? Active : Duplicated
RA = guess_activity(eltype(ty), mode)
dx_righttype = convert(typeof(x), only(tx))
make_zero!(dx_righttype)
_, result = seeded_autodiff_thunk(
Expand All @@ -181,7 +159,7 @@ function DI.value_and_pullback!(
) where {F,B,C}
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
mode = reverse_split_withprimal(backend)
RA = eltype(ty) <: Number ? Active : BatchDuplicated
RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B))
tx_righttype = map(Fix1(convert, typeof(x)), tx)
make_zero!(tx_righttype)
_, result = batch_seeded_autodiff_thunk(
Expand Down Expand Up @@ -213,29 +191,39 @@ end
### Without preparation

function DI.gradient(
f::F,
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
x,
contexts::Vararg{Context,C},
f::F, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{Context,C}
) where {F,C}
f_and_df = get_f_and_df(f, backend)
ders = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...)
grad = first(ders)
return grad
mode = reverse_noprimal(backend)
IA = guess_activity(typeof(x), mode)
grad = make_zero(x)
dinputs = only(
autodiff(mode, f_and_df, Active, annotate(IA, x, grad), map(translate, contexts)...)
)
new_grad = first(dinputs)
if isnothing(new_grad)
return grad
else
return new_grad
end
end

function DI.value_and_gradient(
f::F,
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
x,
contexts::Vararg{Context,C},
f::F, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{Context,C}
) where {F,C}
f_and_df = get_f_and_df(f, backend)
ders, y = gradient(
reverse_withprimal(backend), f_and_df, x, map(translate, contexts)...
mode = reverse_withprimal(backend)
IA = guess_activity(typeof(x), mode)
grad = make_zero(x)
dinputs, result = autodiff(
mode, f_and_df, Active, annotate(IA, x, grad), map(translate, contexts)...
)
grad = first(ders)
return y, grad
new_grad = first(dinputs)
if isnothing(new_grad)
return result, grad
else
return result, new_grad
end
end

### With preparation
Expand All @@ -245,10 +233,7 @@ struct EnzymeGradientPrep{G} <: GradientPrep
end

function DI.prepare_gradient(
f::F,
::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
x,
contexts::Vararg{Context,C},
f::F, ::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{Context,C}
) where {F,C}
grad_righttype = make_zero(x)
return EnzymeGradientPrep(grad_righttype)
Expand All @@ -257,21 +242,18 @@ end
function DI.gradient(
f::F,
::EnzymeGradientPrep,
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
x,
contexts::Vararg{Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
ders = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...)
grad = first(ders)
return grad
return DI.gradient(f, backend, x, contexts...)
end

function DI.gradient!(
f::F,
grad,
prep::EnzymeGradientPrep,
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
x,
contexts::Vararg{Context,C},
) where {F,C}
Expand All @@ -292,23 +274,18 @@ end
function DI.value_and_gradient(
f::F,
::EnzymeGradientPrep,
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
x,
contexts::Vararg{Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
ders, y = gradient(
reverse_withprimal(backend), f_and_df, x, map(translate, contexts)...
)
grad = first(ders)
return y, grad
return DI.value_and_gradient(f, backend, x, contexts...)
end

function DI.value_and_gradient!(
f::F,
grad,
prep::EnzymeGradientPrep,
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
x,
contexts::Vararg{Context,C},
) where {F,C}
Expand All @@ -328,6 +305,9 @@ end

## Jacobian

# TODO: does not support static arrays

#=
struct EnzymeReverseOneArgJacobianPrep{Sy,B} <: JacobianPrep end
function EnzymeReverseOneArgJacobianPrep(::Val{Sy}, ::Val{B}) where {Sy,B}
Expand Down Expand Up @@ -385,3 +365,4 @@ function DI.value_and_jacobian!(
y, new_jac = DI.value_and_jacobian(f, prep, backend, x)
return y, copyto!(jac, new_jac)
end
=#
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,13 @@ end
function maybe_reshape(A::AbstractArray, m, n)
return reshape(A, m, n)
end

annotate(::Type{Active{T}}, x, dx) where {T} = Active(x)
annotate(::Type{Duplicated{T}}, x, dx) where {T} = Duplicated(x, dx)

function annotate(::Type{BatchDuplicated{T,B}}, x, tx::NTuple{B}) where {T,B}
return BatchDuplicated(x, tx)
end

batchify_activity(::Type{Active{T}}, ::Val{B}) where {T,B} = Active{T}
batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T,B} = BatchDuplicated{T,B}
13 changes: 13 additions & 0 deletions DifferentiationInterface/test/Back/Enzyme/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,16 @@ test_differentiation(
sparsity=true,
logging=LOGGING,
);

##

filtered_static_scenarios = filter(static_scenarios()) do s
DIT.operator_place(s) == :out && DIT.function_place(s) == :out
end

test_differentiation(
[AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Reverse)],
filtered_static_scenarios;
excluded=SECOND_ORDER,
logging=LOGGING,
)

0 comments on commit d4b17c1

Please sign in to comment.