Skip to content

Commit

Permalink
Merge #780
Browse files Browse the repository at this point in the history
780: Move a bunch of no_grad to ChainRules r=oxinabox a=oxinabox

this is the partner to JuliaDiff/ChainRules.jl#252
It will fail until that is merged and tagged.

What is left is:

- Types (because of JuliaDiff/ChainRulesCore.jl#213) (e.g. `Colon`, `OneTo`, `Channel`)
- Things for which the derivative is `Zero()`, not `DoesNotExist()` (e.g. `one`, `ones`, `zero`, `zeros`)
- Things that felt too magic: e.g. `Base.eval`


Should I bump patch version and tag a release?

Co-authored-by: Lyndon White <[email protected]>
Co-authored-by: Lyndon White <[email protected]>
  • Loading branch information
3 people authored Sep 4, 2020
2 parents 4ea7ad7 + a2026e7 commit ec3fad4
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 23 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.5.5"
version = "0.5.6"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -27,7 +27,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5"
ArrayLayouts = "0.1, 0.2, 0.3, 0.4"
ChainRules = "0.7.0"
ChainRules = "0.7.16"
DiffRules = "1.0"
FillArrays = "0.8, 0.9"
ForwardDiff = "0"
Expand Down
2 changes: 2 additions & 0 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote us
"""
@inline wrap_chainrules_output(x) = unthunk(x) # For now we are just not going to deal with thunks
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
# Zygote convention: even if there are many AbstractZero partials (i.e. a multi-input function), return just a single `nothing`.
@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing
for T_outer in (:Tuple, :NamedTuple)
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
Expand Down
9 changes: 1 addition & 8 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,7 @@ using Distributed: pmap
# Adjoints for the `Array` constructor: the forward pass calls `Array(xs)` and the
# pullback hands the incoming cotangent `ȳ` back unchanged as the partial for `xs`.
@adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,)
@adjoint Array(xs::Array) = Array(xs), ȳ -> (ȳ,)

@nograd size, length, eachindex, Base.OneTo, axes, Colon(), findfirst, findlast, findall, ones, zeros, one, zero, any, all
@nograd randn, randexp, randn!, randexp!
@static if VERSION > v"1.3"
@nograd Random.default_rng
end

@adjoint Base.rand(rng::AbstractRNG, ::Type{T}, dims...) where {T<:Number} =
rand(rng, T, dims...), _ -> nothing
@nograd ones, zeros, Base.OneTo, Colon(), one, zero

# `Base.vect(xs...)` builds a vector from its arguments, so the pullback splats the
# cotangent `Δ` back into a tuple holding one partial per argument.
@adjoint Base.vect(xs...) = Base.vect(xs...), Δ -> (Δ...,)

Expand Down
7 changes: 1 addition & 6 deletions src/lib/base.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
@nograd readline, Base.gc_num, Base.time_ns, Base.print, Base.println, Base.show,
Core.show, Core.print, Core.println, string, repr, Threads.nthreads, Threads.threadid

# Gradient of AD stacks

# For any `AbstractVector`, the mutable gradient store starts as a fresh, empty,
# untyped vector.
grad_mut(::AbstractVector) = []
Expand Down Expand Up @@ -47,11 +44,9 @@ end
end
end

@nograd haskey

# Channels

@nograd Channel, schedule
@nograd Channel

# For a `Channel`, the mutable gradient store is a new `Channel` constructed with the
# same `sz_max` as `ch`.
grad_mut(ch::Channel) = Channel(ch.sz_max)

Expand Down
2 changes: 0 additions & 2 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ using Base.Broadcast
using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize
using NNlib

@nograd Broadcast.combine_styles, Broadcast.result_style

# There's a saying that debugging code is about twice as hard as writing it in
# the first place. So if you're as clever as you can be when writing code, how
# will you ever debug it?
Expand Down
5 changes: 1 addition & 4 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ function accum(x::RefValue, y::RefValue)
end

# Core functions

@nograd Core.apply_type, Core.typeof, nfields, fieldtype, Core.TypeVar, Core.UnionAll,
(==), (===), (<=), (>=), (<), (>), isempty, supertype, Base.typename,
eps, Meta.parse, Base.eval, sleep, isassigned
@nograd eps, Base.eval, Core.TypeVar, Core.UnionAll

# Adjoint for `deepcopy`: the forward pass returns `deepcopy(x)` and the pullback
# passes the incoming cotangent `ȳ` through unchanged as the partial for `x`.
@adjoint deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)

Expand Down
2 changes: 1 addition & 1 deletion src/lib/number.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

@nograd floor, ceil, trunc, round, hash, div
@nograd floor, ceil, trunc, round, div

@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} =
Base.literal_pow(^,x,Val(p)),
Expand Down
16 changes: 16 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,22 @@ using Zygote, Test, ChainRules
@test mimo_pullback_hitcount[] == 1
end

@testset "all AbstractZero partials" begin
    # A ChainRules pullback always returns one partial per input, whereas Zygote
    # collapses them into a single `nothing` when every partial is zero-like.

    not_diff_eg(x, i) = [10, 20][i]
    function ChainRules.rrule(::typeof(not_diff_eg), x, i)
        # Every partial (self, `x`, `i`) is an `AbstractZero`, whatever `Δ` is.
        not_diff_eg_pullback(Δ) =
            (ChainRules.NO_FIELDS, ChainRules.Zero(), ChainRules.DoesNotExist())
        return not_diff_eg(x, i), not_diff_eg_pullback
    end

    _, pb = Zygote.pullback(not_diff_eg, 10.4, 2)
    @test pb(1.2) === nothing
end

@testset "nested AD hitting identity(::Tuple) pullback" begin
# This is a particularly fiddly case.
# It's kind of a simplified version of `sin'''(0.5)` but different in some places.
Expand Down
5 changes: 5 additions & 0 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1538,9 +1538,14 @@ end
end

@testset "@nograd" begin
@test gradient(x->eachindex([10,20,30])[1], 11) == (nothing,)

# These are defined in ChainRules; we test them here to check that we handle them correctly.
@test gradient(x -> findfirst(ismissing, x), [1, missing]) == (nothing,)
@test gradient(x -> findlast(ismissing, x), [1, missing]) == (nothing,)
@test gradient(x -> findall(ismissing, x)[1], [1, missing]) == (nothing,)


@test gradient(x -> Zygote.ignore(() -> x*x), 1) == (nothing,)
@test gradient(x -> Zygote.@ignore(x*x), 1) == (nothing,)
@test gradient(1) do x
Expand Down

1 comment on commit ec3fad4

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/20841

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.6 -m "<description of version>" ec3fad4ee08f904d1a187fa06d02cb0553670b5c
git push origin v0.5.6

Please sign in to comment.