Skip to content

Commit

Permalink
Merge branch 'JuliaDiff:main' into kron
Browse files Browse the repository at this point in the history
  • Loading branch information
simsurace authored Jan 8, 2024
2 parents d5a8446 + ae37562 commit 3258dee
Show file tree
Hide file tree
Showing 22 changed files with 197 additions and 24 deletions.
1 change: 1 addition & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
style = "blue"
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- x86
- x64
steps:
- uses: actions/checkout@v4.0.0
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/Cancel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
cancel:
runs-on: ubuntu-latest
steps:
- uses: styfle/cancel-workflow-action@0.9.0
- uses: styfle/cancel-workflow-action@0.12.0
with:
all_but_latest: true
workflow_id: ${{ github.event.workflow.id }}
4 changes: 2 additions & 2 deletions .github/workflows/IntegrationTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ jobs:
# package: {user: JuliaDiff, repo: Diffractor.jl}

steps:
- uses: actions/checkout@v4.0.0
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.julia-version }}
arch: x64
- uses: julia-actions/julia-buildpkg@latest
- name: Clone Downstream
uses: actions/checkout@v4.0.0
uses: actions/checkout@v4
with:
repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
path: downstream
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/JuliaNightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
- x86
- x64
steps:
- uses: actions/checkout@v4.0.0
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/VersionVigilante_pull_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
VersionVigilante:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4.0.0
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@latest
- name: VersionVigilante.main
id: versionvigilante_main
Expand Down
27 changes: 27 additions & 0 deletions .github/workflows/format.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: Format suggestions

on:
pull_request:

concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
format:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@latest
with:
version: 1
- run: |
julia -e 'using Pkg; Pkg.add("JuliaFormatter")'
julia -e 'using JuliaFormatter; format("."; verbose=true)'
- uses: reviewdog/action-suggester@v1
with:
tool_name: JuliaFormatter
fail_on_error: true
filter_mode: added
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.54.0"
version = "1.58.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -23,15 +23,21 @@ Adapt = "3.4.0"
ChainRulesCore = "1.15.3"
ChainRulesTestUtils = "1.5"
Compat = "3.46, 4.2"
Distributed = "1"
FiniteDifferences = "0.12.20"
GPUArraysCore = "0.1.0"
IrrationalConstants = "0.1.1, 0.2"
JLArrays = "0.1"
JuliaInterpreter = "0.8,0.9"
LinearAlgebra = "1"
Random = "1"
RealDot = "0.1"
SparseInverseSubset = "0.1"
SparseArrays = "1"
StaticArrays = "1.2"
Statistics = "1"
StructArrays = "0.6.11"
SuiteSparse = "1"
julia = "1.6"

[extras]
Expand Down
1 change: 1 addition & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ include("rulesets/Base/indexing.jl")
include("rulesets/Base/sort.jl")
include("rulesets/Base/mapreduce.jl")
include("rulesets/Base/broadcast.jl")
include("rulesets/Base/CoreLogging.jl")

include("rulesets/Distributed/nondiff.jl")

Expand Down
20 changes: 20 additions & 0 deletions src/rulesets/Base/CoreLogging.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# For the CoreLogging submodule of Base. (not to be confused with the Logging stdlib)

function rrule(
rc::RuleConfig{>:ChainRulesCore.HasReverseMode},
::typeof(Base.CoreLogging.with_logger),
f::Function,
logger::Base.CoreLogging.AbstractLogger,
)
y, f_pb = Base.CoreLogging.with_logger(logger) do
rrule_via_ad(rc, f)
end
with_logger_pullback(ȳ) = (NoTangent(), only(f_pb(ȳ)), NoTangent())
return y, with_logger_pullback
end

@non_differentiable Base.CoreLogging.current_logger(args...)
@non_differentiable Base.CoreLogging.current_logger_for_env(::Any...)
@non_differentiable Base.CoreLogging._invoked_shouldlog(::Any...)
@non_differentiable Base.CoreLogging.Base.fixup_stdlib_path(::Any)
@non_differentiable Base.CoreLogging.handle_message(::Any...)
6 changes: 3 additions & 3 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R

function backslash_pullback(ȳ)
= unthunk(ȳ)

Ȳf =
@static if VERSION >= v"1.9"
# Need to ensure Ȳ is an array since since https://github.com/JuliaLang/julia/pull/44358
Expand All @@ -360,7 +360,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
end
end
Yf = Y
@static if VERSION >= v"1.9"
@static if VERSION >= v"1.9"
# Need to ensure Yf is an array since since https://github.com/JuliaLang/julia/pull/44358
if !isa(Y, AbstractArray)
Yf = [Y]
Expand All @@ -371,7 +371,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
= A' \ Ȳf
= -* Y'
t = (B - A * Y) *'
@static if VERSION >= v"1.9"
@static if VERSION >= v"1.9"
# Need to ensure t is an array since since https://github.com/JuliaLang/julia/pull/44358
if !isa(t, AbstractArray)
t = [t]
Expand Down
1 change: 1 addition & 0 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ end

@scalar_rule fma(x, y, z) (y, x, true)
@scalar_rule muladd(x, y, z) (y, x, true)
@scalar_rule muladd(x::Union{Number, ZeroTangent}, y::Union{Number, ZeroTangent}, z::Union{Number, ZeroTangent}) (y, x, true)
@scalar_rule rem2pi(x, r::RoundingMode) (true, NoTangent())
@scalar_rule(
mod(x, y),
Expand Down
10 changes: 7 additions & 3 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Int rather than Int64/Integer is intentional
function frule((_, ẋ), ::typeof(getfield), x::Tuple, i::Int)
return x.i, ẋ.i
function ChainRulesCore.frule((_, Δ, _), ::typeof(getfield), strct, sym::Union{Int,Symbol})
return (getfield(strct, sym), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym))
end

function ChainRulesCore.frule((_, Δ, _, _), ::typeof(getfield), strct, sym::Union{Int,Symbol}, inbounds)
return (getfield(strct, sym, inbounds), isa(Δ, NoTangent) ? NoTangent() : getproperty(Δ, sym))
end

"for a given tuple type, returns a Val{N} where N is the length of the tuple"
Expand Down Expand Up @@ -140,7 +144,7 @@ end
ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...)

function ∇getindex!(dx::AbstractArray, dy, inds::Integer...)
view(dx, inds...) .+= Ref(dy)
@views dx[inds...] += dy
return dx
end
function ∇getindex!(dx::AbstractArray, dy, inds...)
Expand Down
4 changes: 0 additions & 4 deletions src/rulesets/Base/nondiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -483,10 +483,6 @@ end
@non_differentiable Broadcast.result_style(::Any)
@non_differentiable Broadcast.result_style(::Any, ::Any)

@non_differentiable Base.CoreLogging.current_logger_for_env(::Any...)
@non_differentiable Base.CoreLogging._invoked_shouldlog(::Any...)
@non_differentiable Base.CoreLogging.Base.fixup_stdlib_path(::Any)
@non_differentiable Base.CoreLogging.handle_message(::Any...)

@non_differentiable Libc.free(::Any)
@non_differentiable Libc.getpid()
Expand Down
2 changes: 1 addition & 1 deletion src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ end

function _diagm_back(p, ȳ)
k, v = p
d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix
d = diag(unthunk(ȳ), k)[eachindex(v)] # handle if diagonal was smaller than matrix
return Tangent{typeof(p)}(second = d)
end

Expand Down
23 changes: 23 additions & 0 deletions src/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,26 @@ function rrule(::typeof(det), x::SparseMatrixCSC)
end
return Ω, det_pullback
end


function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...)

function spdiagm_pullback(ȳ)
return (NoTangent(), NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...)
end
return spdiagm(m, n, kv...), spdiagm_pullback
end

function rrule(::typeof(spdiagm), kv::Pair{<:Integer,<:AbstractVector}...)
function spdiagm_pullback(ȳ)
return (NoTangent(), _diagm_back.(kv, Ref(ȳ))...)
end
return spdiagm(kv...), spdiagm_pullback
end

function rrule(::typeof(spdiagm), v::AbstractVector)
function spdiagm_pullback(ȳ)
return (NoTangent(), diag(unthunk(ȳ)))
end
return spdiagm(v), spdiagm_pullback
end
11 changes: 11 additions & 0 deletions test/rulesets/Base/CoreLogging.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# For the CoreLogging submodule of Base. (not to be confused with the Logging stdlib)
@testset "CoreLogging.jl" begin
@testset "with_logger" begin
test_rrule(
Base.CoreLogging.with_logger,
() -> 2.0 * 3.0,
Base.CoreLogging.NullLogger();
check_inferred=false,
)
end
end
10 changes: 10 additions & 0 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,16 @@
test_rrule(muladd, 10randn(), randn(), randn())
end

@testset "muladd ZeroTangent" begin
test_frule(muladd, 2.0, 3.0, ZeroTangent())
test_frule(muladd, 2.0, ZeroTangent(), 4.0)
test_frule(muladd, ZeroTangent(), 3.0, 4.0)

test_rrule(muladd, 2.0, 3.0, ZeroTangent())
test_rrule(muladd, 2.0, ZeroTangent(), 4.0)
test_rrule(muladd, ZeroTangent(), 3.0, 4.0)
end

@testset "fma" begin
test_frule(fma, 10randn(), randn(), randn())
test_rrule(fma, 10randn(), randn(), randn())
Expand Down
24 changes: 24 additions & 0 deletions test/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
struct FooTwoField
x::Float64
y::Float64
end


@testset "getfield" begin
test_frule(getfield, FooTwoField(1.5, 2.5), :x, check_inferred=false)

test_frule(getfield, (; a=1.5, b=2.5), :a, check_inferred=false)
test_frule(getfield, (; a=1.5, b=2.5), 2)

test_frule(getfield, (1.5, 2.5), 2)
test_frule(getfield, (1.5, 2.5), 2, true)
end

@testset "getindex" begin
@testset "getindex(::Tuple, ...)" begin
x = (1.2, 3.4, 5.6)
Expand Down Expand Up @@ -161,6 +177,14 @@
@test Array(y3) == Array(x_23_gpu)[1, [1,1,2]]
@test unthunk(bk3(jl(ones(3)))[2]) == jl([2 1 0; 0 0 0])
end

@testset "getindex(::Array{<:AbstractGPUArray})" begin
x_gpu = jl(rand(1))
y, back = rrule(getindex, [x_gpu], 1)
@test y === x_gpu
dxs_gpu = unthunk(back(jl([1.0]))[2])
@test dxs_gpu == [jl([1.0])]
end
end

# first & tail handled by getfield rules
Expand Down
47 changes: 46 additions & 1 deletion test/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,51 @@ end
test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-4)
end

# copied over from test/rulesets/LinearAlgebra/structured
@testset "spdiagm" begin
@testset "without size" begin
M, N = 7, 9
s = (8, 8)
a = randn(M)
b = randn(M)
c = randn(M - 1)
= randn(s)
ps = (0 => a, 1 => b, 0 => c)
y, back = rrule(spdiagm, ps...)
@test y == spdiagm(ps...)
∂self, ∂pa, ∂pb, ∂pc = back(ȳ)
@test ∂self === NoTangent()
∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(0 => a, 1 => b, 0 => c), ȳ, a, b, c)
for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd))
∂px = unthunk(∂px)
@test ∂px isa Tangent{typeof(p)}
@test ∂px.first isa AbstractZero
@test ∂px.second ∂x_fd
end
end
@testset "with size" begin
M, N = 7, 9
a = randn(M)
b = randn(M)
c = randn(M - 1)
= randn(M, N)
ps = (0 => a, 1 => b, 0 => c)
y, back = rrule(spdiagm, M, N, ps...)
@test y == spdiagm(M, N, ps...)
∂self, ∂M, ∂N, ∂pa, ∂pb, ∂pc = back(ȳ)
@test ∂self === NoTangent()
@test ∂M === NoTangent()
@test ∂N === NoTangent()
∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c)
for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd))
∂px = unthunk(∂px)
@test ∂px isa Tangent{typeof(p)}
@test ∂px.first isa AbstractZero
@test ∂px.second ∂x_fd
end
end
end

@testset "findnz" begin
A = sprand(5, 5, 0.5)
dA = similar(A)
Expand All @@ -42,4 +87,4 @@ end
test_rrule(logabsdet, A)
test_rrule(logdet, A)
test_rrule(det, A)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ end
test_method_tables() # Check the global method tables are consistent

# Each file puts all tests inside one or more @testset blocks
include_test("rulesets/Base/CoreLogging.jl")
include_test("rulesets/Base/base.jl")
include_test("rulesets/Base/fastmath_able.jl")
include_test("rulesets/Base/evalpoly.jl")
Expand Down
13 changes: 8 additions & 5 deletions test/unzipped.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,14 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map
# TODO invent some tests of this rrule's pullback function

@test unzip(jl([(1,2), (3,4), (5,6)])) == (jl([1, 3, 5]), jl([2, 4, 6]))

@test unzip(jl([(missing,2), (missing,4), (missing,6)]))[2] == jl([2, 4, 6])
@test unzip(jl([(missing,2), (missing,4), (missing,6)]))[2] isa Base.ReinterpretArray

@test unzip(jl([(1,), (3,), (5,)]))[1] == jl([1, 3, 5])
@test unzip(jl([(1,), (3,), (5,)]))[1] isa Base.ReinterpretArray

# depending on Julia/package versions, may get ReinterpretArray or JLArray
# Either is acceptable
@test isa(
unzip(jl([(missing, 2), (missing, 4), (missing, 6)]))[2],
Union{Base.ReinterpretArray,JLArray},
)
end
end
end

0 comments on commit 3258dee

Please sign in to comment.