Skip to content

Commit

Permalink
Add type-parameter checks
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Oct 18, 2024
1 parent 96b97e8 commit ecedb09
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
24 changes: 14 additions & 10 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,26 @@ struct MulAddMul{ais1, bis0, TA, TB}
beta::TB
end

function MulAddMul{ais1,bis0}(alpha::TA, beta::TB) where {ais1,bis0,TA,TB}
@noinline throw_alpha(ais1, alpha) = throw(ArgumentError(lazy"alpha = $alpha is inconsistent with the type parameter ais1 = $ais1"))
@noinline throw_beta(bis0, beta) = throw(ArgumentError(lazy"beta = $beta is inconsistent with the type parameter bis0 = $bis0"))
@inline function MulAddMul{ais1,bis0}(alpha::TA, beta::TB) where {ais1,bis0,TA,TB}
xor(ais1, isone(alpha)) && throw_alpha(ais1, alpha)
xor(bis0, iszero(beta)) && throw_beta(bis0, beta)
MulAddMul{ais1,bis0,TA,TB}(alpha,beta)
end

@inline function MulAddMul(alpha, beta)
@inline function MulAddMul(alpha::TA, beta::TB) where {TA,TB}
if isone(alpha)
if iszero(beta)
return MulAddMul{true,true}(alpha, beta)
return MulAddMul{true,true,TA,TB}(alpha, beta)
else
return MulAddMul{true,false}(alpha, beta)
return MulAddMul{true,false,TA,TB}(alpha, beta)
end
else
if iszero(beta)
return MulAddMul{false,true}(alpha, beta)
return MulAddMul{false,true,TA,TB}(alpha, beta)
else
return MulAddMul{false,false}(alpha, beta)
return MulAddMul{false,false,TA,TB}(alpha, beta)
end
end
end
Expand Down Expand Up @@ -87,16 +91,16 @@ macro stable_muladdmul(expr)
local bsym = e.args[3]

local e_sub11 = copy(expr)
e_sub11.args[i] = :(MulAddMul{true, true}($asym, $bsym))
e_sub11.args[i] = :(MulAddMul{true, true, typeof($asym), typeof($bsym)}($asym, $bsym))

local e_sub10 = copy(expr)
e_sub10.args[i] = :(MulAddMul{true, false}($asym, $bsym))
e_sub10.args[i] = :(MulAddMul{true, false, typeof($asym), typeof($bsym)}($asym, $bsym))

local e_sub01 = copy(expr)
e_sub01.args[i] = :(MulAddMul{false, true}($asym, $bsym))
e_sub01.args[i] = :(MulAddMul{false, true, typeof($asym), typeof($bsym)}($asym, $bsym))

local e_sub00 = copy(expr)
e_sub00.args[i] = :(MulAddMul{false, false}($asym, $bsym))
e_sub00.args[i] = :(MulAddMul{false, false, typeof($asym), typeof($bsym)}($asym, $bsym))

local e_out = quote
if isone($asym)
Expand Down
12 changes: 12 additions & 0 deletions stdlib/LinearAlgebra/test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
module TestGeneric

using Test, LinearAlgebra, Random
using LinearAlgebra: MulAddMul

const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")

Expand All @@ -20,6 +21,17 @@ using .Main.FillArrays

Random.seed!(123)

@testset "MulAddMul" begin
@test_throws "alpha = 2.0 is inconsistent" MulAddMul{true,true}(2.0, 2.0)
@test_throws "alpha = 2.0 is inconsistent" MulAddMul{true,false}(2.0, 2.0)
@test_throws "alpha = true is inconsistent" MulAddMul{false,true}(true, 2.0)
@test_throws "alpha = true is inconsistent" MulAddMul{false,false}(true, 2.0)
@test_throws "beta = 2.0 is inconsistent" MulAddMul{true,true}(true, 2.0)
@test_throws "beta = 2.0 is inconsistent" MulAddMul{false,true}(2.0, 2.0)
@test_throws "beta = false is inconsistent" MulAddMul{true,false}(true, false)
@test_throws "beta = false is inconsistent" MulAddMul{false,false}(2.0, false)
end

n = 5 # should be odd

@testset for elty in (Int, Rational{BigInt}, Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFloat})
Expand Down

0 comments on commit ecedb09

Please sign in to comment.