From a9f1dccfa4ceced5eeed19a27412d4a4f04d2cf1 Mon Sep 17 00:00:00 2001 From: Daniel Ingraham Date: Thu, 20 Jul 2023 13:08:47 -0400 Subject: [PATCH] Missed a few other TrackedArray, TrackedReal references --- src/external.jl | 4 ++-- src/linear.jl | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/external.jl b/src/external.jl index f6592dc..4c87777 100644 --- a/src/external.jl +++ b/src/external.jl @@ -196,5 +196,5 @@ function ChainRulesCore.rrule(::typeof(_provide_rule), func, x, p, mode, jacobia return y, pullback end -ReverseDiff.@grad_from_chainrules _provide_rule(func, x::TrackedArray, p, mode, jacobian, jvp, vjp) -ReverseDiff.@grad_from_chainrules _provide_rule(func, x::AbstractArray{<:TrackedReal}, p, mode, jacobian, jvp, vjp) +ReverseDiff.@grad_from_chainrules _provide_rule(func, x::ReverseDiff.TrackedArray, p, mode, jacobian, jvp, vjp) +ReverseDiff.@grad_from_chainrules _provide_rule(func, x::AbstractArray{<:ReverseDiff.TrackedReal}, p, mode, jacobian, jvp, vjp) diff --git a/src/linear.jl b/src/linear.jl index 4ba4f16..080f67b 100644 --- a/src/linear.jl +++ b/src/linear.jl @@ -199,9 +199,9 @@ function ChainRulesCore.rrule(::typeof(_implicit_linear), A, b, lsolve, Af) end # register above rule for ReverseDiff -ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{TrackedArray, AbstractArray{<:TrackedReal}}, b, lsolve, Af) -ReverseDiff.@grad_from_chainrules _implicit_linear(A, b::Union{TrackedArray, AbstractArray{<:TrackedReal}}, lsolve, Af) -ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{TrackedArray, AbstractArray{<:TrackedReal}}, b::Union{TrackedArray, AbstractVector{<:TrackedReal}}, lsolve, Af) +ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b, lsolve, Af) +ReverseDiff.@grad_from_chainrules _implicit_linear(A, b::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, lsolve, Af) +ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b::Union{ReverseDiff.TrackedArray, AbstractVector{<:ReverseDiff.TrackedReal}}, lsolve, Af) # function implicit_linear_inplace(A, b, y, Af)