Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
vpuri3 committed Aug 2, 2022
1 parent 2ecde1c commit a70190a
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions examples/ad/zy.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#
using Zygote, LinearAlgebra

N = 4
u0 = rand(N)
ps = rand(N)

mats = (rand(N,N), rand(N,N),) # (A, B,)
nums = (rand(), rand(),) # (α, β,)

loss_m = function(p)
v = Diagonal(p) * u0
v = Zygote.hook-> (println("Δv: ", typeof(Δ)); Δ), v)

w = foldl((acc, op) -> op * acc, mats; init=v) # w = B * A * v
w = Zygote.hook-> (println("Δw: ", Δ); Δ), w)

l = sum(w)
l = Zygote.hook-> (println("Δl: ", Δ); Δ), l)
end

println("fwd"); @time loss_m(ps) |> display
println("bwd"); @time Zygote.gradient(loss_m, ps) |> display # INCORRECT - should not vanish

loss_n = function(p)
v = Diagonal(p) * u0
v = Zygote.hook-> (println("Δv: ", typeof(Δ)); Δ), v)

w = sum(a -> convert(Number, a), nums; init=zero(eltype(nums))) * v # w = αβ * v
#w = sum(a -> convert(Number, a), nums) * v # w = αβ * v
w = Zygote.hook-> (println("Δw: ", Δ); Δ), w)

l = sum(w)
l = Zygote.hook-> (println("Δl: ", Δ); Δ), l)
end

println("fwd"); @time loss_n(ps) |> display
println("bwd"); @time Zygote.gradient(loss_n, ps) |> display # ERRORS

#

0 comments on commit a70190a

Please sign in to comment.