diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 078bb602a..765b55b48 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -434,6 +434,21 @@ function rrule(::typeof(-), x::AbstractArray) return -x, negation_pullback end +##### +##### Subtraction +##### + +frule((_, Δx, Δy), ::typeof(-), x::AbstractArray, y::AbstractArray) = x - y, Δx - Δy + +function rrule(::typeof(-), x::AbstractArray, y::AbstractArray) + xproj = ProjectTo(x) + yproj = ProjectTo(y) + function subtract_pullback(dy_raw) + dy = unthunk(dy_raw) # projs will otherwise unthunk twice + return (NoTangent(), xproj(dy), yproj(-dy)) + end + return x - y, subtract_pullback +end ##### ##### Addition (Multiarg `+`) diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index 2682f3b8a..2099b9b3b 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -217,4 +217,13 @@ @gpu test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4)) @gpu test_rrule(+, randn(3), randn(3,1), randn(3,1,1)) end + + @testset "subtraction" begin + # fwd + @gpu test_frule(-, randn(2), randn(2)) + # rev + @gpu test_rrule(-, randn(4, 4), randn(4, 4)) + @gpu test_rrule(-, randn(4), randn(ComplexF64, 4)) + @gpu test_rrule(-, randn(3), randn(3, 1)) + end end