Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rrule for spdiagm #740

Merged
merged 10 commits into from
Oct 24, 2023
2 changes: 1 addition & 1 deletion src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ end

function _diagm_back(p, ȳ)
k, v = p
d = diag(unthunk(ȳ), k)[1:length(v)] # handle if diagonal was smaller than matrix
d = diag(unthunk(ȳ), k)[eachindex(v)] # handle if diagonal was smaller than matrix
return Tangent{typeof(p)}(second = d)
end

Expand Down
23 changes: 23 additions & 0 deletions src/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,26 @@ function rrule(::typeof(det), x::SparseMatrixCSC)
end
return Ω, det_pullback
end


function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...)

function spdiagm_pullback(ȳ)
return (NoTangent(), NoTangent(), NoTangent(), _diagm_back.(kv, Ref(ȳ))...)
end
return spdiagm(m, n, kv...), spdiagm_pullback
end

function rrule(::typeof(spdiagm), kv::Pair{<:Integer,<:AbstractVector}...)
function spdiagm_pullback(ȳ)
return (NoTangent(), _diagm_back.(kv, Ref(ȳ))...)
end
return spdiagm(kv...), spdiagm_pullback
end

function rrule(::typeof(spdiagm), v::AbstractVector)
function spdiagm_pullback(ȳ)
return (NoTangent(), diag(unthunk(ȳ)))
end
return spdiagm(v), spdiagm_pullback
end
47 changes: 46 additions & 1 deletion test/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,51 @@ end
test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-4)
end

# copied over from test/rulesets/LinearAlgebra/structured
@testset "spdiagm" begin
@testset "without size" begin
M, N = 7, 9
s = (8, 8)
a = randn(M)
b = randn(M)
c = randn(M - 1)
ȳ = randn(s)
ps = (0 => a, 1 => b, 0 => c)
y, back = rrule(spdiagm, ps...)
ElOceanografo marked this conversation as resolved.
Show resolved Hide resolved
@test y == spdiagm(ps...)
∂self, ∂pa, ∂pb, ∂pc = back(ȳ)
@test ∂self === NoTangent()
∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(0 => a, 1 => b, 0 => c), ȳ, a, b, c)
for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd))
∂px = unthunk(∂px)
@test ∂px isa Tangent{typeof(p)}
@test ∂px.first isa AbstractZero
@test ∂px.second ≈ ∂x_fd
end
end
@testset "with size" begin
M, N = 7, 9
a = randn(M)
b = randn(M)
c = randn(M - 1)
ȳ = randn(M, N)
ps = (0 => a, 1 => b, 0 => c)
y, back = rrule(spdiagm, M, N, ps...)
@test y == spdiagm(M, N, ps...)
∂self, ∂M, ∂N, ∂pa, ∂pb, ∂pc = back(ȳ)
@test ∂self === NoTangent()
@test ∂M === NoTangent()
@test ∂N === NoTangent()
∂a_fd, ∂b_fd, ∂c_fd = j′vp(_fdm, (a, b, c) -> spdiagm(M, N, 0 => a, 1 => b, 0 => c), ȳ, a, b, c)
for (p, ∂px, ∂x_fd) in zip(ps, (∂pa, ∂pb, ∂pc), (∂a_fd, ∂b_fd, ∂c_fd))
∂px = unthunk(∂px)
@test ∂px isa Tangent{typeof(p)}
@test ∂px.first isa AbstractZero
@test ∂px.second ≈ ∂x_fd
end
end
end

@testset "findnz" begin
A = sprand(5, 5, 0.5)
dA = similar(A)
Expand All @@ -42,4 +87,4 @@ end
test_rrule(logabsdet, A)
test_rrule(logdet, A)
test_rrule(det, A)
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ end

ElOceanografo marked this conversation as resolved.
Show resolved Hide resolved
println()

include_test("rulesets/SparseArrays/sparsematrix.jl")
include("rulesets/SparseArrays/sparsematrix.jl")
oxinabox marked this conversation as resolved.
Show resolved Hide resolved

println()

Expand Down
Loading