From 060bf95a78f33cd5f29edd4d3e54e65009fa5bec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20G=C3=B6ttgens?= Date: Fri, 18 Aug 2023 20:54:21 +0200 Subject: [PATCH] Add `setindex!` for `MatElem` with linear indexing (#1407) --- src/Matrix.jl | 23 +++++++++++++++++++++++ src/generic/Matrix.jl | 17 ----------------- test/generic/Matrix-test.jl | 14 ++++++++++++++ 3 files changed, 37 insertions(+), 17 deletions(-) diff --git a/src/Matrix.jl b/src/Matrix.jl index 6a18713a48..06ace4281d 100644 --- a/src/Matrix.jl +++ b/src/Matrix.jl @@ -440,6 +440,29 @@ Base.@propagate_inbounds function Base.setindex!(a::MatrixElem{T}, x, I::Cartesi a end +# linear indexing for row- or column- vectors +Base.@propagate_inbounds function getindex(M::MatrixElem, i::Integer) + if nrows(M) == 1 + M[1, i] + elseif ncols(M) == 1 + M[i, 1] + else + throw(ArgumentError("linear indexing not supported for non-vector matrices")) + end +end + +Base.@propagate_inbounds function setindex!(M::MatrixElem, x, i::Integer) + if nrows(M) == 1 + M[1, i] = x + return M + elseif ncols(M) == 1 + M[i, 1] = x + return M + else + throw(ArgumentError("linear indexing not supported for non-vector matrices")) + end +end + # iteration function Base.iterate(a::MatrixElem{T}, ij=(0, 1)) where T <: NCRingElement diff --git a/src/generic/Matrix.jl b/src/generic/Matrix.jl index 79796c66a6..cbee44fb6b 100644 --- a/src/generic/Matrix.jl +++ b/src/generic/Matrix.jl @@ -100,23 +100,6 @@ function deepcopy_internal(d::MatSpaceView{T}, dict::IdDict) where T <: NCRingEl return MatSpaceView(deepcopy_internal(d.entries, dict), d.base_ring) end -############################################################################### -# -# getindex -# -############################################################################### - -# linear indexing for row- or column- vectors -Base.@propagate_inbounds function getindex(M::MatElem, x::Integer) - if nrows(M) == 1 - M[1, x] - elseif ncols(M) == 1 - M[x, 1] - else - throw(ArgumentError("linear indexing not supported for non-vector matrices")) - end -end - function Base.view(M::Mat{T}, rows::AbstractUnitRange{Int}, cols::AbstractUnitRange{Int}) where T <: NCRingElement return MatSpaceView(view(M.entries, rows, cols), M.base_ring) end diff --git a/test/generic/Matrix-test.jl b/test/generic/Matrix-test.jl index 0d7a615d65..f685826cd4 100644 --- a/test/generic/Matrix-test.jl +++ b/test/generic/Matrix-test.jl @@ -747,6 +747,20 @@ end elseif length(A) >= 1 @test_throws ArgumentError A[1] end + A = deepcopy(A) + if nr == 1 + c = rand(1:nc) + d = rand(1:nc) + A[c] = A[d] + @test A[c] == A[1, d] + elseif nc == 1 + r = rand(1:nr) + s = rand(1:nr) + A[r] = A[s] + @test A[r] == A[s, 1] + elseif length(A) >= 1 + @test_throws ArgumentError (A[1] = zero(base_ring(A))) + end end for (_, (R, rand_params)) in RINGS