From 542abbf291559f5c953b9ee53862a3d318eef0ff Mon Sep 17 00:00:00 2001 From: Jutho Date: Mon, 30 May 2022 15:48:17 +0200 Subject: [PATCH] restrict strided array multiplication rrule Fix the specialized `rrule` for `StridedArray` multiplication to equal `eltype`. Fixes issue [#625](https://github.com/JuliaDiff/ChainRules.jl/issues/625). --- src/rulesets/Base/arraymath.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index f1409515c..5878b93b4 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -46,9 +46,9 @@ end # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/411 function rrule( ::typeof(*), - A::StridedMatrix{<:CommutativeMulNumber}, - B::StridedVecOrMat{<:CommutativeMulNumber}, -) + A::StridedMatrix{T}, + B::StridedVecOrMat{T}, +) where {T<:CommutativeMulNumber} function times_pullback(ȳ) Ȳ = unthunk(ȳ) dA = InplaceableThunk(