Skip to content

Commit

Permalink
restrict muladd frule and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Jul 11, 2023
1 parent f0095e0 commit 8b9cd00
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
8 changes: 7 additions & 1 deletion src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,13 @@ end # VERSION
##### `muladd`
#####

function frule((_, ΔA, ΔB, Δz), ::typeof(muladd), A, B, z)
function frule(
(_, ΔA, ΔB, Δz),
::typeof(muladd),
A::AbstractVecOrMat{<:CommutativeMulNumber},
B::AbstractVecOrMat{<:CommutativeMulNumber},
z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}}
)
Ω = muladd(A, B, z)
return Ω, ΔA * B .+ A * ΔB .+ Δz
end
Expand Down
11 changes: 8 additions & 3 deletions test/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,44 +85,49 @@

@testset "muladd: $T" for T in (Float64, ComplexF64)
@testset "add $(typeof(z))" for z in [rand(), rand(T, 3), rand(T, 3, 3), false]
@testset "forward mode" begin
@gpu test_frule(muladd, rand(T, 3, 5), rand(T, 5, 3), z)
end
@testset "matrix * matrix" begin
A = rand(T, 3, 3)
B = rand(T, 3, 3)
@gpu test_rrule(muladd, A, B, z)
@gpu test_rrule(muladd, A', B, z)
@gpu test_rrule(muladd, A , B', z)
@gpu test_frule(muladd, A, B, z)
@gpu test_frule(muladd, A', B, z)
@gpu test_frule(muladd, A , B', z)

C = rand(T, 3, 5)
D = rand(T, 5, 3)
@gpu test_rrule(muladd, C, D, z)
@gpu test_frule(muladd, C, D, z)
end
if ndims(z) <= 1
@testset "matrix * vector" begin
A, B = rand(T, 3, 3), rand(T, 3)
test_rrule(muladd, A, B, z)
test_rrule(muladd, A, B rand(T, 3,1), z)
test_frule(muladd, A, B, z)
end
@testset "adjoint * matrix" begin
At, B = rand(T, 3)', rand(T, 3, 3)
test_rrule(muladd, At, B, z')
test_rrule(muladd, At rand(T,1,3), B, z')
test_frule(muladd, At, B, z')
end
end
if ndims(z) == 0
@testset "adjoint * vector" begin # like dot
A, B = rand(T, 3)', rand(T, 3)
test_rrule(muladd, A, B, z)
test_rrule(muladd, A rand(T,1,3), B, z')
test_frule(muladd, A, B, z)
end
end
if ndims(z) == 2 # other dims lead to e.g. muladd(ones(4), ones(1,4), 1)
@testset "vector * adjoint" begin # outer product
A, B = rand(T, 3), rand(T, 3)'
test_rrule(muladd, A, B, z)
test_rrule(muladd, A, B rand(T,1,3), z)
test_frule(muladd, A, B, z)
end
end
end
Expand Down

0 comments on commit 8b9cd00

Please sign in to comment.