From d07a863ce4ffe6f33ae2f241ac759ea19a2ce044 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 29 May 2024 19:15:50 +0200 Subject: [PATCH] Accomodate for rectangular matrices in `copytrito!` (#54587) (cherry picked from commit fc54be6eac4409afd831b37d56d4a7796fdc3565) --- stdlib/LinearAlgebra/src/generic.jl | 25 ++++++++----- stdlib/LinearAlgebra/src/lapack.jl | 20 ++++++++-- stdlib/LinearAlgebra/test/generic.jl | 56 +++++++++++++++++++++++++--- stdlib/LinearAlgebra/test/lapack.jl | 20 +++++++++- 4 files changed, 101 insertions(+), 20 deletions(-) diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index 81d092ca14060..c2144bf85d024 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -1934,19 +1934,24 @@ function copytrito!(B::AbstractMatrix, A::AbstractMatrix, uplo::AbstractChar) BLAS.chkuplo(uplo) m,n = size(A) m1,n1 = size(B) - (m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)")) A = Base.unalias(B, A) if uplo == 'U' - for j=1:n - for i=1:min(j,m) - @inbounds B[i,j] = A[i,j] - end + if n < m + (m1 < n || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($n,$n)")) + else + (m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)")) end - else # uplo == 'L' - for j=1:n - for i=j:m - @inbounds B[i,j] = A[i,j] - end + for j in 1:n, i in 1:min(j,m) + @inbounds B[i,j] = A[i,j] + end + else # uplo == 'L' + if m < n + (m1 < m || n1 < m) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$m)")) + else + (m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)")) + end + for j in 1:n, i in j:m + @inbounds B[i,j] = A[i,j] end end return B diff --git a/stdlib/LinearAlgebra/src/lapack.jl b/stdlib/LinearAlgebra/src/lapack.jl index 254425f8cd2bf..4b7bffcd50e89 100644 --- a/stdlib/LinearAlgebra/src/lapack.jl +++ b/stdlib/LinearAlgebra/src/lapack.jl @@ -7163,9 +7163,23 @@ for (fn, elty) in ((:dlacpy_, :Float64), function lacpy!(B::AbstractMatrix{$elty}, A::AbstractMatrix{$elty}, uplo::AbstractChar) require_one_based_indexing(A, B) chkstride1(A, B) - m,n = size(A) - m1,n1 = size(B) - (m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)")) + m, n = size(A) + m1, n1 = size(B) + if uplo == 'U' + if n < m + (m1 < n || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($n,$n)")) + else + (m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)")) + end + elseif uplo == 'L' + if m < n + (m1 < m || n1 < m) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$m)")) + else + (m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)")) + end + else + (m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)")) + end lda = max(1, stride(A, 2)) ldb = max(1, stride(B, 2)) ccall((@blasfunc($fn), libblastrampoline), Cvoid, diff --git a/stdlib/LinearAlgebra/test/generic.jl b/stdlib/LinearAlgebra/test/generic.jl index ba4bdb1845255..fd464d6c0762c 100644 --- a/stdlib/LinearAlgebra/test/generic.jl +++ b/stdlib/LinearAlgebra/test/generic.jl @@ -647,12 +647,56 @@ end @testset "copytrito!" begin n = 10 - A = rand(n, n) - for uplo in ('L', 'U') - B = zeros(n, n) - copytrito!(B, A, uplo) - C = uplo == 'L' ? tril(A) : triu(A) - @test B ≈ C + @testset "square" begin + for A in (rand(n, n), rand(Int8, n, n)), uplo in ('L', 'U') + for AA in (A, view(A, reverse.(axes(A))...)) + C = uplo == 'L' ? tril(AA) : triu(AA) + for B in (zeros(n, n), zeros(n+1, n+2)) + copytrito!(B, AA, uplo) + @test view(B, 1:n, 1:n) == C + end + end + end + end + @testset "wide" begin + for A in (rand(n, 2n), rand(Int8, n, 2n)) + for AA in (A, view(A, reverse.(axes(A))...)) + C = tril(AA) + for (M, N) in ((n, n), (n+1, n), (n, n+1), (n+1, n+1)) + B = zeros(M, N) + copytrito!(B, AA, 'L') + @test view(B, 1:n, 1:n) == view(C, 1:n, 1:n) + end + @test_throws DimensionMismatch copytrito!(zeros(n-1, 2n), AA, 'L') + C = triu(AA) + for (M, N) in ((n, 2n), (n+1, 2n), (n, 2n+1), (n+1, 2n+1)) + B = zeros(M, N) + copytrito!(B, AA, 'U') + @test view(B, 1:n, 1:2n) == view(C, 1:n, 1:2n) + end + @test_throws DimensionMismatch copytrito!(zeros(n+1, 2n-1), AA, 'U') + end + end + end + @testset "tall" begin + for A in (rand(2n, n), rand(Int8, 2n, n)) + for AA in (A, view(A, reverse.(axes(A))...)) + C = triu(AA) + for (M, N) in ((n, n), (n+1, n), (n, n+1), (n+1, n+1)) + B = zeros(M, N) + copytrito!(B, AA, 'U') + @test view(B, 1:n, 1:n) == view(C, 1:n, 1:n) + end + @test_throws DimensionMismatch copytrito!(zeros(n-1, n+1), AA, 'U') + C = tril(AA) + for (M, N) in ((2n, n), (2n, n+1), (2n+1, n), (2n+1, n+1)) + B = zeros(M, N) + copytrito!(B, AA, 'L') + @test view(B, 1:2n, 1:n) == view(C, 1:2n, 1:n) + end + @test_throws DimensionMismatch copytrito!(zeros(n-1, n+1), AA, 'L') + end + end end @testset "aliasing" begin M = Matrix(reshape(1:36, 6, 6)) diff --git a/stdlib/LinearAlgebra/test/lapack.jl b/stdlib/LinearAlgebra/test/lapack.jl index 652c6c2e27e6c..000438a004b23 100644 --- a/stdlib/LinearAlgebra/test/lapack.jl +++ b/stdlib/LinearAlgebra/test/lapack.jl @@ -805,8 +805,26 @@ end B = zeros(elty, n, n) LinearAlgebra.LAPACK.lacpy!(B, A, uplo) C = uplo == 'L' ? tril(A) : (uplo == 'U' ? triu(A) : A) - @test B ≈ C + @test B == C + B = zeros(elty, n+1, n+1) + LinearAlgebra.LAPACK.lacpy!(B, A, uplo) + C = uplo == 'L' ? tril(A) : (uplo == 'U' ? triu(A) : A) + @test view(B, 1:n, 1:n) == C end + A = rand(elty, n, n+1) + B = zeros(elty, n, n) + LinearAlgebra.LAPACK.lacpy!(B, A, 'L') + @test B == view(tril(A), 1:n, 1:n) + B = zeros(elty, n, n+1) + LinearAlgebra.LAPACK.lacpy!(B, A, 'U') + @test B == triu(A) + A = rand(elty, n+1, n) + B = zeros(elty, n, n) + LinearAlgebra.LAPACK.lacpy!(B, A, 'U') + @test B == view(triu(A), 1:n, 1:n) + B = zeros(elty, n+1, n) + LinearAlgebra.LAPACK.lacpy!(B, A, 'L') + @test B == tril(A) end end