Skip to content

Commit

Permalink
Split branches into separate functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Oct 23, 2024
1 parent 00c2e6f commit ff1a06c
Showing 1 changed file with 40 additions and 43 deletions.
83 changes: 40 additions & 43 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -831,73 +831,70 @@ end
# legacy method, retained for backward compatibility
_generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) =
_generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
alpha::Number, beta::Number)
require_one_based_indexing(C, A, B)
@assert tA in ('N', 'T', 'C')
mB = length(B)
mA, nA = lapack_size(tA, A)
if mB != nA
throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), vector B has length $mB"))
end
if mA != length(C)
throw(DimensionMismatch(lazy"result C has length $(length(C)), needs length $mA"))
end

function __generic_matvecmul!(f::F, C::AbstractVector, A::AbstractVecOrMat, B::AbstractVector,
alpha::Number, beta::Number) where {F}
Astride = size(A, 1)

@inbounds begin
if tA == 'T' # fastest case
if nA == 0
for k = 1:mA
@stable_muladdmul _modify!(MulAddMul(alpha,beta), false, C, k)
end
else
for k = 1:mA
aoffs = (k-1)*Astride
firstterm = transpose(A[aoffs + 1])*B[1]
s = zero(firstterm + firstterm)
for i = 1:nA
s += transpose(A[aoffs+i]) * B[i]
end
@stable_muladdmul _modify!(MulAddMul(alpha,beta), s, C, k)
end
end
elseif tA == 'C'
if nA == 0
for k = 1:mA
if length(B) == 0
for k = eachindex(C)
@stable_muladdmul _modify!(MulAddMul(alpha,beta), false, C, k)
end
else
for k = 1:mA
for k = eachindex(C)
aoffs = (k-1)*Astride
firstterm = A[aoffs + 1]'B[1]
firstterm = f(A[aoffs + 1]) * B[1]
s = zero(firstterm + firstterm)
for i = 1:nA
s += A[aoffs + i]'B[i]
for i = eachindex(B)
s += f(A[aoffs+i]) * B[i]
end
@stable_muladdmul _modify!(MulAddMul(alpha,beta), s, C, k)
end
end
else # tA == 'N'
for i = 1:mA
end
end
function __generic_matvecmul!(::typeof(identity), C::AbstractVector, A::AbstractVecOrMat, B::AbstractVector,
alpha::Number, beta::Number)
Astride = size(A, 1)
@inbounds begin
for i = eachindex(C)
if !iszero(beta)
C[i] *= beta
elseif mB == 0
elseif length(B) == 0
C[i] = false
else
C[i] = zero(A[i]*B[1] + A[i]*B[1])
end
end
for k = 1:mB
for k = eachindex(B)
aoffs = (k-1)*Astride
b = @stable_muladdmul MulAddMul(alpha,beta)(B[k])
for i = 1:mA
for i = eachindex(C)
C[i] += A[aoffs + i] * b
end
end
end
end # @inbounds
return C
end
function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
alpha::Number, beta::Number)
require_one_based_indexing(C, A, B)
@assert tA in ('N', 'T', 'C')
mB = length(B)
mA, nA = lapack_size(tA, A)
if mB != nA
throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), vector B has length $mB"))
end
if mA != length(C)
throw(DimensionMismatch(lazy"result C has length $(length(C)), needs length $mA"))
end

if tA == 'T' # fastest case
__generic_matvecmul!(transpose, C, A, B, alpha, beta)
elseif tA == 'C'
__generic_matvecmul!(adjoint, C, A, B, alpha, beta)
else # tA == 'N'
__generic_matvecmul!(identity, C, A, B, alpha, beta)
end
C
end

Expand Down

0 comments on commit ff1a06c

Please sign in to comment.