From ecedb09258115878b0f172446f2a646335bf2570 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 18 Oct 2024 12:41:58 +0530 Subject: [PATCH] Add type-parameter checks --- stdlib/LinearAlgebra/src/generic.jl | 24 ++++++++++++++---------- stdlib/LinearAlgebra/test/generic.jl | 12 ++++++++++++ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index 2043dd9149f72..0c52c666293ef 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -33,22 +33,26 @@ struct MulAddMul{ais1, bis0, TA, TB} beta::TB end -function MulAddMul{ais1,bis0}(alpha::TA, beta::TB) where {ais1,bis0,TA,TB} +@noinline throw_alpha(ais1, alpha) = throw(ArgumentError(lazy"alpha = $alpha is inconsistent with the type parameter ais1 = $ais1")) +@noinline throw_beta(bis0, beta) = throw(ArgumentError(lazy"beta = $beta is inconsistent with the type parameter bis0 = $bis0")) +@inline function MulAddMul{ais1,bis0}(alpha::TA, beta::TB) where {ais1,bis0,TA,TB} + xor(ais1, isone(alpha)) && throw_alpha(ais1, alpha) + xor(bis0, iszero(beta)) && throw_beta(bis0, beta) MulAddMul{ais1,bis0,TA,TB}(alpha,beta) end -@inline function MulAddMul(alpha, beta) +@inline function MulAddMul(alpha::TA, beta::TB) where {TA,TB} if isone(alpha) if iszero(beta) - return MulAddMul{true,true}(alpha, beta) + return MulAddMul{true,true,TA,TB}(alpha, beta) else - return MulAddMul{true,false}(alpha, beta) + return MulAddMul{true,false,TA,TB}(alpha, beta) end else if iszero(beta) - return MulAddMul{false,true}(alpha, beta) + return MulAddMul{false,true,TA,TB}(alpha, beta) else - return MulAddMul{false,false}(alpha, beta) + return MulAddMul{false,false,TA,TB}(alpha, beta) end end end @@ -87,16 +91,16 @@ macro stable_muladdmul(expr) local bsym = e.args[3] local e_sub11 = copy(expr) - e_sub11.args[i] = :(MulAddMul{true, true}($asym, $bsym)) + e_sub11.args[i] = :(MulAddMul{true, true, typeof($asym), typeof($bsym)}($asym, $bsym)) local e_sub10 = copy(expr) - e_sub10.args[i] = :(MulAddMul{true, false}($asym, $bsym)) + e_sub10.args[i] = :(MulAddMul{true, false, typeof($asym), typeof($bsym)}($asym, $bsym)) local e_sub01 = copy(expr) - e_sub01.args[i] = :(MulAddMul{false, true}($asym, $bsym)) + e_sub01.args[i] = :(MulAddMul{false, true, typeof($asym), typeof($bsym)}($asym, $bsym)) local e_sub00 = copy(expr) - e_sub00.args[i] = :(MulAddMul{false, false}($asym, $bsym)) + e_sub00.args[i] = :(MulAddMul{false, false, typeof($asym), typeof($bsym)}($asym, $bsym)) local e_out = quote if isone($asym) diff --git a/stdlib/LinearAlgebra/test/generic.jl b/stdlib/LinearAlgebra/test/generic.jl index e0a1704913f78..cfc367698172a 100644 --- a/stdlib/LinearAlgebra/test/generic.jl +++ b/stdlib/LinearAlgebra/test/generic.jl @@ -3,6 +3,7 @@ module TestGeneric using Test, LinearAlgebra, Random +using LinearAlgebra: MulAddMul const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test") @@ -20,6 +21,17 @@ using .Main.FillArrays Random.seed!(123) +@testset "MulAddMul" begin + @test_throws "alpha = 2.0 is inconsistent" MulAddMul{true,true}(2.0, 2.0) + @test_throws "alpha = 2.0 is inconsistent" MulAddMul{true,false}(2.0, 2.0) + @test_throws "alpha = true is inconsistent" MulAddMul{false,true}(true, 2.0) + @test_throws "alpha = true is inconsistent" MulAddMul{false,false}(true, 2.0) + @test_throws "beta = 2.0 is inconsistent" MulAddMul{true,true}(true, 2.0) + @test_throws "beta = 2.0 is inconsistent" MulAddMul{false,true}(2.0, 2.0) + @test_throws "beta = false is inconsistent" MulAddMul{true,false}(true, false) + @test_throws "beta = false is inconsistent" MulAddMul{false,false}(2.0, false) +end + n = 5 # should be odd @testset for elty in (Int, Rational{BigInt}, Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFloat})