diff --git a/src/linear.jl b/src/linear.jl index 080f67b..c918837 100644 --- a/src/linear.jl +++ b/src/linear.jl @@ -66,12 +66,13 @@ This version is for linear equations Ay = b - `A::matrix`, `b::vector`: components of linear system ``A y = b`` - `lsolve::function`: lsolve(A, b). Function to solve the linear system, default is backslash operator. - `Af::factorization`: An optional factorization of A, useful to override default factorize, or if multiple linear solves will be performed with same A matrix. +- `mmul`: Function to compute ``A*y`` for a vector ``y``. Defaults to the julia multipy operator. """ -implicit_linear(A, b; lsolve=linear_solve, Af=nothing) = _implicit_linear(A, b, lsolve, Af) +implicit_linear(A, b; lsolve=linear_solve, Af=nothing, mmul=*) = _implicit_linear(A, b, lsolve, mmul, Af) # If no AD, just solve normally. -_implicit_linear(A, b, lsolve, Af) = isnothing(Af) ? lsolve(A, b) : lsolve(Af, b) +_implicit_linear(A, b, lsolve, mmul, Af) = isnothing(Af) ? lsolve(A, b) : lsolve(Af, b) # function implicit_linear(A, b, lsolve, Af, cache) # if isnothing(cache) # if isnothing(Af) @@ -90,9 +91,9 @@ _implicit_linear(A, b, lsolve, Af) = isnothing(Af) ? lsolve(A, b) : lsolve(Af, b # end # catch three cases where one or both contain duals -_implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b::AbstractArray{<:ForwardDiff.Dual{T}}, lsolve, Af) where {T} = linear_dual(A, b, lsolve, Af, T) -_implicit_linear(A, b::AbstractArray{<:ForwardDiff.Dual{T}}, lsolve, Af) where {T} = linear_dual(A, b, lsolve, Af, T) -_implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b, lsolve, Af) where {T} = linear_dual(A, b, lsolve, Af, T) +_implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b::AbstractArray{<:ForwardDiff.Dual{T}}, lsolve, mmul, Af) where {T} = linear_dual(A, b, lsolve, mmul, Af, T) +_implicit_linear(A, b::AbstractArray{<:ForwardDiff.Dual{T}}, lsolve, mmul, Af) where {T} = linear_dual(A, b, lsolve, mmul, Af, T) +_implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b, lsolve, mmul, Af) where {T} = linear_dual(A, b, lsolve, mmul, Af, T) # implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b::AbstractArray{<:ForwardDiff.Dual{T}}, lsolve, Af, cache) where {T} = isnothing(cache) ? linear_dual(A, b, lsolve, Af, T) : linear_dual(A, b, lsolve, Af, T, cache) # implicit_linear(A, b::AbstractArray{<:ForwardDiff.Dual{T}}, lsolve, Af, cache) where {T} = isnothing(cache) ? linear_dual(A, b, lsolve, Af, T) : linear_dual(A, b, lsolve, Af, T, cache) # implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b, lsolve, Af, cache) where {T} = isnothing(cache) ? linear_dual(A, b, lsolve, Af, T) : linear_dual(A, b, lsolve, Af, T, cache) @@ -110,7 +111,7 @@ _implicit_linear(A::AbstractArray{<:ForwardDiff.Dual{T}}, b, lsolve, Af) where { # implicit_linear!(ydot, A::AbstractArray{<:ForwardDiff.Dual{T}}, b, lsolve, Af) where {T} = linear_dual!(ydot, A, b, lsolve, Af, T) # Both A and b contain duals -function linear_dual(A, b, lsolve, Af, T) +function linear_dual(A, b, lsolve, mmul, Af, T) # unpack dual numbers (if not dual numbers, since only one might be, just returns itself) bv = fd_value(b) @@ -122,7 +123,7 @@ function linear_dual(A, b, lsolve, Af, T) yv = lsolve(Afact, bv) # extract Partials of b - A * y i.e., bdot - Adot * y (since y does not contain duals) - rhs = fd_partials(b - A*yv) + rhs = fd_partials(b - mmul(A,yv)) # solve for new derivatives ydot = lsolve(Afact, rhs) @@ -182,7 +183,7 @@ end # Provide a ChainRule rule for reverse mode -function ChainRulesCore.rrule(::typeof(_implicit_linear), A, b, lsolve, Af) +function ChainRulesCore.rrule(::typeof(_implicit_linear), A, b, lsolve, mmul, Af) # save factorization Afact = isnothing(Af) ? factorize(ReverseDiff.value(A)) : Af @@ -192,16 +193,16 @@ function ChainRulesCore.rrule(::typeof(_implicit_linear), A, b, lsolve, Af) function implicit_pullback(ybar) u = lsolve(Afact', ybar) - return NoTangent(), -u*y', u, NoTangent(), NoTangent() + return NoTangent(), -u*y', u, NoTangent(), NoTangent(), NoTangent() end return y, implicit_pullback end # register above rule for ReverseDiff -ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b, lsolve, Af) -ReverseDiff.@grad_from_chainrules _implicit_linear(A, b::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, lsolve, Af) -ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b::Union{ReverseDiff.TrackedArray, AbstractVector{<:ReverseDiff.TrackedReal}}, lsolve, Af) +ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b, lsolve, mmul, Af) +ReverseDiff.@grad_from_chainrules _implicit_linear(A, b::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, lsolve, mmul, Af) +ReverseDiff.@grad_from_chainrules _implicit_linear(A::Union{ReverseDiff.TrackedArray, AbstractArray{<:ReverseDiff.TrackedReal}}, b::Union{ReverseDiff.TrackedArray, AbstractVector{<:ReverseDiff.TrackedReal}}, lsolve, mmul, Af) # function implicit_linear_inplace(A, b, y, Af) diff --git a/test/runtests.jl b/test/runtests.jl index 30d1457..0aad926 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -292,6 +292,69 @@ end @test all(isapprox.(J1, J2, atol=3e-12)) end +@testset "linear_user_mmul" begin + + count = 0 + + function my_multiply(A, x) + (m, n) = size(A) + T = promote_type(eltype(A), eltype(x)) + y = zeros(T, m) + for j in 1:n + for i in 1:m + y[i] += A[i, j] * x[j] + end + end + # Provide a way to make sure this function was called + count += 1 + return y + end + + function test(a) + A = a[1] * [1.0 2.0 3.0; 4.1 5.3 6.4; 7.4 8.6 9.7] + b = 2.0 * a[2:4] + x = implicit_linear(A, b; mmul=my_multiply) + z = 2 * x + return z + end + + function test2(a) + A = [1.0 2.0 3.0; 4.1 5.3 6.4; 7.4 8.6 9.7] + b = 2.0 * a[2:4] + x = implicit_linear(A, b; mmul=my_multiply) + z = 2 * x + return z + end + + function test3(a) + A = a[1] * [1.0 2.0 3.0; 4.1 5.3 6.4; 7.4 8.6 9.7] + b = 2.0 * ones(3) + x = implicit_linear(A, b; mmul=my_multiply) + z = 2 * x + return z + end + + a = [1.2, 2.3, 3.1, 4.3] + J1 = ForwardDiff.jacobian(test, a) + J2 = ReverseDiff.jacobian(test, a) + + @test count == 1 + @test all(isapprox.(J1, J2, atol=3e-12)) + + J1 = ForwardDiff.jacobian(test2, a) + J2 = ReverseDiff.jacobian(test2, a) + + @test count == 2 + @test all(isapprox.(J1, J2, atol=3e-12)) + + J1 = ForwardDiff.jacobian(test3, a) + J2 = ReverseDiff.jacobian(test3, a) + + @test count == 3 + @test all(isapprox.(J1, J2, atol=3e-12)) + +end + @testset "1d (also parameters)" begin residual(y, x, p) = y/x[1] + x[2]*cos(y) @@ -991,4 +1054,4 @@ end @test all(isapprox.(J1, J2, atol=1e-15)) @test all(isapprox.(J1, Jfd, atol=1e-9)) -end \ No newline at end of file +end