Skip to content

Commit

Permalink
Fix multiplying a triangular matrix and a Diagonal
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Oct 21, 2024
1 parent cba1cc0 commit 786ba2f
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 64 deletions.
2 changes: 2 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,8 @@ matprod_dest(A::StructuredMatrix, B::Diagonal, TS) = _matprod_dest_diag(A, TS)
matprod_dest(A::Diagonal, B::StructuredMatrix, TS) = _matprod_dest_diag(B, TS)
matprod_dest(A::Diagonal, B::Diagonal, TS) = _matprod_dest_diag(B, TS)
_matprod_dest_diag(A, TS) = similar(A, TS)
_matprod_dest_diag(A::UnitUpperTriangular, TS) = UpperTriangular(similar(parent(A), TS))
_matprod_dest_diag(A::UnitLowerTriangular, TS) = LowerTriangular(similar(parent(A), TS))
function _matprod_dest_diag(A::SymTridiagonal, TS)
n = size(A, 1)
ev = similar(A, TS, max(0, n-1))
Expand Down
190 changes: 127 additions & 63 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,82 +396,156 @@ function lmul!(D::Diagonal, T::Tridiagonal)
return T
end

function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul)
@inbounds for j in axes(B, 2)
@simd for i in axes(B, 1)
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
end
end
out
end
_has_matching_storage(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true
_has_matching_storage(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true
_has_matching_storage(out, A) = false
function _rowrange_tri_stored(B::UpperOrUnitUpperTriangular, col)
isunit = B isa UnitUpperTriangular
1:min(col-isunit, size(B,1))
end
function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col)
isunit = B isa UnitLowerTriangular
col+isunit:size(B,1)
end
_rowrange_tri_nonstored(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1)
_rowrange_tri_nonstored(B::LowerOrUnitLowerTriangular, col) = 1:col-1
function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul)
isunit = B isa UnitUpperOrUnitLowerTriangular
out_maybeparent, B_maybeparent = _has_matching_storage(out, B) ? (parent(out), parent(B)) : (out, B)
for j in axes(B, 2)
# store the diagonal separately for unit triangular matrices
if isunit
@inbounds _modify!(_add, D.diag[j] * B[j,j], out, (j,j))
end
# indices of out corresponding to the stored indices of B
rowrange = _rowrange_tri_stored(B, j)
@inbounds @simd for i in rowrange
_modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
end
# indices of out corresponding to the zeros of B
# we only fill these if out and B don't have matching zeros
if !_has_matching_storage(out, B)
rowrange = _rowrange_tri_nonstored(B, j)
if haszero(eltype(out))
_rmul_or_fill!(@view(out[rowrange,j]), _add.beta)
else
@inbounds @simd for i in rowrange
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
end
end
end
end
out
end
function __muldiag!(out, D::Diagonal, B, _add::MulAddMul)
require_one_based_indexing(out, B)
alpha, beta = _add.alpha, _add.beta
if iszero(alpha)
_rmul_or_fill!(out, beta)
else
if bis0
@inbounds for j in axes(B, 2)
@simd for i in axes(B, 1)
out[i,j] = D.diag[i] * B[i,j] * alpha
end
end
else
@inbounds for j in axes(B, 2)
@simd for i in axes(B, 1)
out[i,j] = D.diag[i] * B[i,j] * alpha + out[i,j] * beta
__muldiag_nonzeroalpha!(out, D, B, _add)
end
return out
end

@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
beta = _add.beta
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
@inbounds for j in axes(A, 2)
dja = _add(D.diag[j])
@simd for i in axes(A, 1)
_modify!(_add_aisone, A[i,j] * dja, out, (i,j))
end
end
out
end
function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
isunit = A isa UnitUpperOrUnitLowerTriangular
beta = _add.beta
# since alpha is multiplied to the diagonal element of D,
# we may skip alpha in the second multiplication by setting ais1 to true
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
# if both A and out have the same upper/lower triangular structure,
# we may directly read and write from the parents
out_maybeparent, A_maybeparent = _has_matching_storage(out, A) ? (parent(out), parent(A)) : (out, A)
for j in axes(A, 2)
dja = _add(@inbounds D.diag[j])
# store the diagonal separately for unit triangular matrices
if isunit
@inbounds _modify!(_add_aisone, A[j,j] * dja, out, (j,j))
end
# indices of out corresponding to the stored indices of A
rowrange = _rowrange_tri_stored(A, j)
@inbounds @simd for i in rowrange
_modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
end
# indices of out corresponding to the zeros of A
# we only fill these if out and A don't have matching zeros
if !_has_matching_storage(out, A)
rowrange = _rowrange_tri_nonstored(A, j)
if haszero(eltype(out))
_rmul_or_fill!(@view(out[rowrange,j]), _add.beta)
else
@inbounds @simd for i in rowrange
_modify!(_add, A[i,j] * dja, out, (i,j))
end
end
end
end
return out
out
end
function __muldiag!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
function __muldiag!(out, A, D::Diagonal, _add::MulAddMul)
require_one_based_indexing(out, A)
alpha, beta = _add.alpha, _add.beta
if iszero(alpha)
_rmul_or_fill!(out, beta)
else
if bis0
@inbounds for j in axes(A, 2)
dja = D.diag[j] * alpha
@simd for i in axes(A, 1)
out[i,j] = A[i,j] * dja
end
end
else
@inbounds for j in axes(A, 2)
dja = D.diag[j] * alpha
@simd for i in axes(A, 1)
out[i,j] = A[i,j] * dja + out[i,j] * beta
end
end
end
__muldiag_nonzeroalpha!(out, A, D, _add)
end
return out
end
function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}

@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
d1 = D1.diag
d2 = D2.diag
outd = out.diag
@inbounds @simd for i in eachindex(d1, d2, outd)
_modify!(_add, d1[i] * d2[i], outd, i)
end
out
end
function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
alpha, beta = _add.alpha, _add.beta
if iszero(alpha)
_rmul_or_fill!(out.diag, beta)
else
if bis0
@inbounds @simd for i in eachindex(out.diag)
out.diag[i] = d1[i] * d2[i] * alpha
end
else
@inbounds @simd for i in eachindex(out.diag)
out.diag[i] = d1[i] * d2[i] * alpha + out.diag[i] * beta
end
end
__muldiag_nonzeroalpha!(out, D1, D2, _add)
end
return out
end
function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
require_one_based_indexing(out)
alpha, beta = _add.alpha, _add.beta
mA = size(D1, 1)
@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
d1 = D1.diag
d2 = D2.diag
@inbounds @simd for i in eachindex(d1, d2)
_modify!(_add, d1[i] * d2[i], out, (i,i))
end
out
end
function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1}) where {ais1}
require_one_based_indexing(out)
alpha, beta = _add.alpha, _add.beta
_rmul_or_fill!(out, beta)
if !iszero(alpha)
@inbounds @simd for i in 1:mA
out[i,i] += d1[i] * d2[i] * alpha
end
_add_bis1 = MulAddMul{ais1,false,typeof(alpha),Bool}(alpha,true)
__muldiag_nonzeroalpha!(out, D1, D2, _add_bis1)
end
return out
end
Expand Down Expand Up @@ -658,31 +732,21 @@ for Tri in (:UpperTriangular, :LowerTriangular)
@eval $fun(A::$Tri, D::Diagonal) = $Tri($fun(A.data, D))
@eval $fun(A::$UTri, D::Diagonal) = $Tri(_setdiag!($fun(A.data, D), $f, D.diag))
end
@eval *(A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
@invoke *(A::AbstractMatrix, D::Diagonal)
@eval *(A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
@invoke *(A::AbstractMatrix, D::Diagonal)
for (fun, f) in zip((:*, :lmul!, :ldiv!, :\), (:identity, :identity, :inv, :inv))
@eval $fun(D::Diagonal, A::$Tri) = $Tri($fun(D, A.data))
@eval $fun(D::Diagonal, A::$UTri) = $Tri(_setdiag!($fun(D, A.data), $f, D.diag))
end
@eval *(D::Diagonal, A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}) =
@invoke *(D::Diagonal, A::AbstractMatrix)
@eval *(D::Diagonal, A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}) =
@invoke *(D::Diagonal, A::AbstractMatrix)
# 3-arg ldiv!
@eval ldiv!(C::$Tri, D::Diagonal, A::$Tri) = $Tri(ldiv!(C.data, D, A.data))
@eval ldiv!(C::$Tri, D::Diagonal, A::$UTri) = $Tri(_setdiag!(ldiv!(C.data, D, A.data), inv, D.diag))
# 3-arg mul! is disambiguated in special.jl
# 5-arg mul!
@eval _mul!(C::$Tri, D::Diagonal, A::$Tri, _add) = $Tri(mul!(C.data, D, A.data, _add.alpha, _add.beta))
@eval function _mul!(C::$Tri, D::Diagonal, A::$UTri, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
α, β = _add.alpha, _add.beta
iszero(α) && return _rmul_or_fill!(C, β)
diag′ = bis0 ? nothing : diag(C)
data = mul!(C.data, D, A.data, α, β)
$Tri(_setdiag!(data, _add, D.diag, diag′))
end
@eval _mul!(C::$Tri, A::$Tri, D::Diagonal, _add) = $Tri(mul!(C.data, A.data, D, _add.alpha, _add.beta))
@eval function _mul!(C::$Tri, A::$UTri, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
α, β = _add.alpha, _add.beta
iszero(α) && return _rmul_or_fill!(C, β)
diag′ = bis0 ? nothing : diag(C)
data = mul!(C.data, A.data, D, α, β)
$Tri(_setdiag!(data, _add, D.diag, diag′))
end
end

@inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal)
Expand Down
22 changes: 22 additions & 0 deletions stdlib/LinearAlgebra/test/addmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,26 @@ end
end
end

@testset "Diagonal scaling of a triangular matrix with a non-triangular destination" begin
for MT in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular)
U = MT(reshape([1:9;],3,3))
M = Array(U)
D = Diagonal(1:3)
A = reshape([1:9;],3,3)
@test mul!(copy(A), U, D, 2, 3) == M * D * 2 + A * 3
@test mul!(copy(A), D, U, 2, 3) == D * M * 2 + A * 3

# nan values with iszero(alpha)
D = Diagonal(fill(NaN,3))
@test mul!(copy(A), U, D, 0, 3) == A * 3
@test mul!(copy(A), D, U, 0, 3) == A * 3

# nan values with iszero(beta)
A = fill(NaN,3,3)
D = Diagonal(1:3)
@test mul!(copy(A), U, D, 2, 0) == M * D * 2
@test mul!(copy(A), D, U, 2, 0) == D * M * 2
end
end

end # module
40 changes: 39 additions & 1 deletion stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,7 @@ end
@test oneunit(D3) isa typeof(D3)
end

@testset "AbstractTriangular" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular))
@testset "$Tri" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular))
A = randn(4, 4)
TriA = Tri(A)
UTriA = UTri(A)
Expand Down Expand Up @@ -1218,6 +1218,44 @@ end
@test outTri === mul!(outTri, D, UTriA, 2, 1)::Tri == mul!(out, D, Matrix(UTriA), 2, 1)
@test outTri === mul!(outTri, TriA, D, 2, 1)::Tri == mul!(out, Matrix(TriA), D, 2, 1)
@test outTri === mul!(outTri, UTriA, D, 2, 1)::Tri == mul!(out, Matrix(UTriA), D, 2, 1)

# we may write to a Unit triangular if the diagonal is preserved
ID = Diagonal(ones(size(UTriA,2)))
@test mul!(copy(UTriA), UTriA, ID) == UTriA
@test mul!(copy(UTriA), ID, UTriA) == UTriA

@testset "partly filled parents" begin
M = Matrix{BigFloat}(undef, 2, 2)
M[1,1] = M[2,2] = 3
isupper = Tri == UpperTriangular
M[1+!isupper, 1+isupper] = 3
D = Diagonal(1:2)
T = Tri(M)
TA = Array(T)
@test T * D == TA * D
@test D * T == D * TA
@test mul!(copy(T), T, D, 2, 3) == 2T * D + 3T
@test mul!(copy(T), D, T, 2, 3) == 2D * T + 3T

U = UTri(M)
UA = Array(U)
@test U * D == UA * D
@test D * U == D * UA
@test mul!(copy(T), U, D, 2, 3) == 2 * UA * D + 3TA
@test mul!(copy(T), D, U, 2, 3) == 2 * D * UA + 3TA

M2 = Matrix{BigFloat}(undef, 2, 2)
M2[1+!isupper, 1+isupper] = 3
U = UTri(M2)
UA = Array(U)
@test U * D == UA * D
@test D * U == D * UA
ID = Diagonal(ones(size(U,2)))
@test mul!(copy(U), U, ID) == U
@test mul!(copy(U), ID, U) == U
@test mul!(copy(U), U, ID, 2, -1) == U
@test mul!(copy(U), ID, U, 2, -1) == U
end
end

struct SMatrix1{T} <: AbstractArray{T,2}
Expand Down

0 comments on commit 786ba2f

Please sign in to comment.