From a70190ad70fad71d1eac90e1b866274737fee4e0 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Tue, 2 Aug 2022 17:12:41 -0400 Subject: [PATCH] zygote MWE https://github.com/FluxML/Zygote.jl/issues/1279 --- examples/ad/zy.jl | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 examples/ad/zy.jl diff --git a/examples/ad/zy.jl b/examples/ad/zy.jl new file mode 100644 index 0000000..b7071d7 --- /dev/null +++ b/examples/ad/zy.jl @@ -0,0 +1,40 @@ +# +using Zygote, LinearAlgebra + +N = 4 +u0 = rand(N) +ps = rand(N) + +mats = (rand(N,N), rand(N,N),) # (A, B,) +nums = (rand(), rand(),) # (α, β,) + +loss_m = function(p) + v = Diagonal(p) * u0 + v = Zygote.hook(Δ -> (println("Δv: ", typeof(Δ)); Δ), v) + + w = foldl((acc, op) -> op * acc, mats; init=v) # w = B * A * v + w = Zygote.hook(Δ -> (println("Δw: ", Δ); Δ), w) + + l = sum(w) + l = Zygote.hook(Δ -> (println("Δl: ", Δ); Δ), l) +end + +println("fwd"); @time loss_m(ps) |> display +println("bwd"); @time Zygote.gradient(loss_m, ps) |> display # INCORRECT - should not vanish + +loss_n = function(p) + v = Diagonal(p) * u0 + v = Zygote.hook(Δ -> (println("Δv: ", typeof(Δ)); Δ), v) + + w = sum(a -> convert(Number, a), nums; init=zero(eltype(nums))) * v # w = αβ * v + #w = sum(a -> convert(Number, a), nums) * v # w = αβ * v + w = Zygote.hook(Δ -> (println("Δw: ", Δ); Δ), w) + + l = sum(w) + l = Zygote.hook(Δ -> (println("Δl: ", Δ); Δ), l) +end + +println("fwd"); @time loss_n(ps) |> display +println("bwd"); @time Zygote.gradient(loss_n, ps) |> display # ERRORS + +#