diff --git a/Project.toml b/Project.toml
index 330f8b5465..01ac5cdeaa 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,7 +6,6 @@ version = "0.14.17"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
-Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -26,6 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 [weakdeps]
 AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
@@ -33,6 +33,7 @@ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 FluxAMDGPUExt = "AMDGPU"
 FluxCUDAExt = "CUDA"
 FluxCUDAcuDNNExt = ["CUDA", "cuDNN"]
+FluxEnzymeExt = "Enzyme"
 FluxMetalExt = "Metal"
 
 [compat]
diff --git a/ext/FluxEnzymeExt/FluxEnzymeExt.jl b/ext/FluxEnzymeExt/FluxEnzymeExt.jl
new file mode 100644
index 0000000000..e6ce51297f
--- /dev/null
+++ b/ext/FluxEnzymeExt/FluxEnzymeExt.jl
@@ -0,0 +1,47 @@
+module FluxEnzymeExt
+
+using Flux
+import Flux.Train: train!, _rule_to_state
+import Flux.Optimise
+import Optimisers
+import Enzyme
+using Enzyme: EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal
+using ProgressLogging: @withprogress, @logprogress
+
+_make_zero_internal!(x::AbstractArray) = fill!(x, 0)
+_make_zero_internal!(x) = x
+_make_zero!(model) = fmap(_make_zero_internal!, model)
+
+_applyloss(loss, model, d...) = loss(model, d...)
+
+EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true
+
+using Flux: _old_to_new  # from src/deprecations.jl
+train!(loss, model::Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
+  train!(loss, model, data, _old_to_new(opt); cb)
+
+function train!(loss, model::Duplicated, data, rule::Optimisers.AbstractRule; cb = nothing)
+  train!(loss, model, data, _rule_to_state(model, rule); cb)
+end
+
+function train!(loss, model::Duplicated, data, opt; cb = nothing)
+  isnothing(cb) || error("""train! does not support callback functions.
+                            For more control use a loop with `gradient` and `update!`.""")
+  @withprogress for (i,d) in enumerate(data)
+    d_splat = d isa Tuple ? d : (d,)
+
+    _make_zero!(model.dval)
+    _, l = Enzyme.autodiff(ReverseWithPrimal, _applyloss,
+                           Active, Const(loss), model, map(Const, d_splat)...)
+
+    if !isfinite(l)
+      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
+    end
+    opt, model2 = Optimisers.update!(opt, model.val, model.dval)
+    model = Duplicated(model2, model.dval)
+
+    @logprogress Base.haslength(data) ? i/length(data) : nothing
+  end
+end
+
+end # FluxEnzymeExt
\ No newline at end of file
diff --git a/src/deprecations.jl b/src/deprecations.jl
index 24372a570e..9306671494 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -109,8 +109,7 @@ train!(loss, ps::Params, data, opt::Optimisers.AbstractRule; cb=nothing) = error
 
 train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
   train!(loss, model, data, _old_to_new(opt); cb)
-train!(loss, model::Enzyme.Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
-  train!(loss, model, data, _old_to_new(opt); cb)
+
 
 # Next, to use the new `setup` with the still-exported old-style `Adam` etc:
 import .Train: setup
diff --git a/src/functor.jl b/src/functor.jl
index e48246ebde..eeaffab1c3 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -3,7 +3,6 @@ using LinearAlgebra: Cholesky
 using Zygote: IdSet
 import Functors: Functors, @functor, functor, fmap, isleaf
 using SparseArrays: AbstractSparseArray
-using Enzyme
 
 """
     testmode!(model, [mode]) -> model
diff --git a/src/losses/utils.jl b/src/losses/utils.jl
index c380564908..b59032d726 100644
--- a/src/losses/utils.jl
+++ b/src/losses/utils.jl
@@ -1,4 +1,3 @@
-import Enzyme
 
 """
     xlogx(x)
@@ -38,4 +37,3 @@ end
 _check_sizes(ŷ, y) = nothing  # pass-through, for constant label e.g. y = 1
 
 ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any)
-Enzyme.EnzymeRules.inactive(::typeof(_check_sizes), args...) = true
diff --git a/src/train.jl b/src/train.jl
index 6094e13ac6..d2cbbd40fa 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -5,7 +5,6 @@ using Optimisers: Optimisers
 using Functors: fmap, fmapstructure
 using ..Flux: Flux # used only in docstring
 import ..Flux.Optimise: train!, update!  # during 0.13, we add methods to the old functions
-import Enzyme
 
 export setup, train!
 
@@ -53,12 +52,6 @@ function setup(rule::Optimisers.AbstractRule, model)
   state
 end
 
-_make_zero_internal!(x::AbstractArray) = fill!(x, 0)
-_make_zero_internal!(x) = x
-_make_zero!(model) = fmap(_make_zero_internal!, model)
-
-_applyloss(loss, model, d...) = loss(model, d...)
-
 """
     train!(loss, model, data, opt_state)
 
@@ -67,7 +60,7 @@ according to a particular optimisation rule encoded in `opt_state`.
 Iterates through `data` once, evaluating for each `d in data` either
 `loss(model, d...)` if `d isa Tuple`, or else `loss(model, d)` for other `d`.
 
-If `model` is an Enzyme.Duplicated, gradients will be computed with Enzyme,
+If `model` is an Enzyme.Duplicated and `Enzyme.jl` is loaded, gradients will be computed with Enzyme,
 otherwise they will be computed with Zygote.
 
 For example, with these definitions...
@@ -122,32 +115,12 @@ function train!(loss, model, data, opt; cb = nothing)
     @logprogress Base.haslength(data) ? i/length(data) : nothing
   end
 end
 
-function train!(loss, model::Enzyme.Duplicated, data, opt; cb = nothing)
-  isnothing(cb) || error("""train! does not support callback functions.
-                            For more control use a loop with `gradient` and `update!`.""")
-  @withprogress for (i,d) in enumerate(data)
-    d_splat = d isa Tuple ? d : (d,)
-
-    _make_zero!(model.dval)
-    _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model, map(Enzyme.Const, d_splat)...)
-
-    if !isfinite(l)
-      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
-    end
-    opt, model2 = Optimisers.update!(opt, model.val, model.dval)
-    model = Enzyme.Duplicated(model2, model.dval)
-    @logprogress Base.haslength(data) ? i/length(data) : nothing
-  end
-end
 
 # This method let you use Optimisers.Descent() without setup, when there is no state
 function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
   train!(loss, model, data, _rule_to_state(model, rule); cb)
 end
-function train!(loss, model::Enzyme.Duplicated, data, rule::Optimisers.AbstractRule; cb = nothing)
-  train!(loss, model, data, _rule_to_state(model, rule); cb)
-end
 
 function _rule_to_state(model, rule::Optimisers.AbstractRule)
   state = setup(rule, model)
diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl
index 8241a3f8dd..25284c5a3f 100644
--- a/test/ext_enzyme/enzyme.jl
+++ b/test/ext_enzyme/enzyme.jl
@@ -1,15 +1,12 @@
 using Test
 using Flux
-using Enzyme
+using Enzyme: Enzyme, make_zero, Active, Duplicated, ReverseWithPrimal
+
 using Functors
 using FiniteDifferences
 using CUDA
 
-_make_zero(x::Union{Number,AbstractArray}) = zero(x)
-_make_zero(x) = x
-make_zero(model) = fmap(_make_zero, model)
-## make_differential(model) = fmapstructure(make_zero, model)  # NOT SUPPORTED, See https://github.com/EnzymeAD/Enzyme.jl/issues/1329
 
 
 function gradient_fd(f, x...)
     x = [cpu(x) for x in x]
diff --git a/test/train.jl b/test/train.jl
index 3ed0e658ea..4c0c12b1b6 100644
--- a/test/train.jl
+++ b/test/train.jl
@@ -4,10 +4,10 @@ import Optimisers
 
 using Test
 using Random
-using Enzyme
+import Enzyme
 
 function train_enzyme!(fn, model, args...; kwargs...)
-  Flux.train!(fn, Duplicated(model, Enzyme.make_zero(model)), args...; kwargs...)
+  Flux.train!(fn, Enzyme.Duplicated(model, Enzyme.make_zero(model)), args...; kwargs...)
 end
 
 for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
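
Usage sketch (not part of the patch): once Enzyme.jl is loaded alongside Flux, the FluxEnzymeExt methods above are reached by wrapping the model in Enzyme.Duplicated before calling train!, as the updated docstring describes. The model, loss function, and data below are hypothetical placeholders chosen only to illustrate the call pattern, not taken from this diff.

    using Flux, Enzyme

    # Hypothetical model and a single (x, y) batch, just to show the shapes involved.
    model = Chain(Dense(2 => 8, relu), Dense(8 => 1))
    data  = [(randn(Float32, 2, 16), randn(Float32, 1, 16))]
    loss(m, x, y) = Flux.mse(m(x), y)

    # The shadow copy produced by Enzyme.make_zero holds the gradients during training.
    dup_model = Enzyme.Duplicated(model, Enzyme.make_zero(model))
    opt_state = Flux.setup(Adam(1e-3), model)

    # Dispatches to the train!(loss, ::Duplicated, data, opt) method added in FluxEnzymeExt.
    Flux.train!(loss, dup_model, data, opt_state)

Without the Enzyme.Duplicated wrapper, the same Flux.train! call falls back to the existing Zygote path in src/train.jl.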