From 2c3e48d005c210c60de4cd286beb1d2c83956190 Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Fri, 9 Feb 2024 11:31:25 +0100 Subject: [PATCH 1/5] add subtract rule --- src/rulesets/Base/arraymath.jl | 10 ++++++++++ test/rulesets/Base/arraymath.jl | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 078bb602a..fa0b34d53 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -434,6 +434,16 @@ function rrule(::typeof(-), x::AbstractArray) return -x, negation_pullback end +##### +##### Subtraction +##### + +frule((_, Δx, Δy), ::typeof(-), x::AbstractArray, y::AbstractArray) = x-y, Δx-Δy + +function rrule(::typeof(-), x::AbstractArray, y::AbstractArray) + subtract_pullback(dy) = (NoTangent(), dy, -dy) + return x-y, subtract_pullback +end ##### ##### Addition (Multiarg `+`) diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index 2682f3b8a..2b4970b56 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -217,4 +217,12 @@ @gpu test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4)) @gpu test_rrule(+, randn(3), randn(3,1), randn(3,1,1)) end + + @testset "subtraction" begin + # fwd + @gpu test_frule(-, randn(2), randn(2)) + # rev + @gpu test_rrule(-, randn(4, 4), randn(4, 4)) + @gpu test_rrule(-, randn(3), randn(3,1)) + end end From 98f893f0f74ff9d852a66064c88dde3322871341 Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Fri, 9 Feb 2024 11:43:31 +0100 Subject: [PATCH 2/5] format --- src/rulesets/Base/arraymath.jl | 4 ++-- test/rulesets/Base/arraymath.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index fa0b34d53..74ef2f579 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -438,11 +438,11 @@ end ##### Subtraction ##### -frule((_, Δx, Δy), ::typeof(-), x::AbstractArray, y::AbstractArray) = x-y, Δx-Δy +frule((_, Δx, Δy), ::typeof(-), x::AbstractArray, y::AbstractArray) = x - y, Δx - Δy function rrule(::typeof(-), x::AbstractArray, y::AbstractArray) subtract_pullback(dy) = (NoTangent(), dy, -dy) - return x-y, subtract_pullback + return x - y, subtract_pullback end ##### diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index 2b4970b56..6894911cd 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -223,6 +223,6 @@ @gpu test_frule(-, randn(2), randn(2)) # rev @gpu test_rrule(-, randn(4, 4), randn(4, 4)) - @gpu test_rrule(-, randn(3), randn(3,1)) + @gpu test_rrule(-, randn(3), randn(3, 1)) end end From 40fd8dc0f2813b1a5280e6f52f7b5221355d90b7 Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Fri, 9 Feb 2024 12:50:52 +0100 Subject: [PATCH 3/5] use ProjectTo --- src/rulesets/Base/arraymath.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 74ef2f579..18c70a1c5 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -441,7 +441,12 @@ end frule((_, Δx, Δy), ::typeof(-), x::AbstractArray, y::AbstractArray) = x - y, Δx - Δy function rrule(::typeof(-), x::AbstractArray, y::AbstractArray) - subtract_pullback(dy) = (NoTangent(), dy, -dy) + xproj = ProjectTo(x) + yproj = ProjectTo(y) + function subtract_pullback(dy_raw) + dy = unthunk(dy_raw) # projs will otherwise unthunk twice + (NoTangent(), xproj(dy), yproj(-dy)) + end return x - y, subtract_pullback end From b717c2b2383be46a0e30042294c96a41a06ea4d0 Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Fri, 9 Feb 2024 12:56:27 +0100 Subject: [PATCH 4/5] format --- src/rulesets/Base/arraymath.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 18c70a1c5..765b55b48 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -445,7 +445,7 @@ function rrule(::typeof(-), x::AbstractArray, y::AbstractArray) yproj = ProjectTo(y) function subtract_pullback(dy_raw) dy = unthunk(dy_raw) # projs will otherwise unthunk twice - (NoTangent(), xproj(dy), yproj(-dy)) + return (NoTangent(), xproj(dy), yproj(-dy)) end return x - y, subtract_pullback end From f44437ca6440b5f0effcf9e64d4efcf9e301294c Mon Sep 17 00:00:00 2001 From: Niklas Heim Date: Fri, 9 Feb 2024 17:36:46 +0100 Subject: [PATCH 5/5] test projectto --- test/rulesets/Base/arraymath.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index 6894911cd..2099b9b3b 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -223,6 +223,7 @@ @gpu test_frule(-, randn(2), randn(2)) # rev @gpu test_rrule(-, randn(4, 4), randn(4, 4)) + @gpu test_rrule(-, randn(4), randn(ComplexF64, 4)) @gpu test_rrule(-, randn(3), randn(3, 1)) end end