diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index f1409515c..5878b93b4 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -46,9 +46,9 @@ end # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/411 function rrule( ::typeof(*), - A::StridedMatrix{<:CommutativeMulNumber}, - B::StridedVecOrMat{<:CommutativeMulNumber}, -) + A::StridedMatrix{T}, + B::StridedVecOrMat{T}, +) where {T<:CommutativeMulNumber} function times_pullback(ȳ) Ȳ = unthunk(ȳ) dA = InplaceableThunk(