Skip to content

Commit

Permalink
Add setindex! for MatElem with linear indexing (#1407)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgoettgens authored Aug 18, 2023
1 parent e537fdd commit 060bf95
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 17 deletions.
23 changes: 23 additions & 0 deletions src/Matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,29 @@ Base.@propagate_inbounds function Base.setindex!(a::MatrixElem{T}, x, I::Cartesi
a
end

# linear indexing for row- or column- vectors
Base.@propagate_inbounds function getindex(M::MatrixElem, i::Integer)
if nrows(M) == 1
M[1, i]
elseif ncols(M) == 1
M[i, 1]
else
throw(ArgumentError("linear indexing not supported for non-vector matrices"))
end
end

Base.@propagate_inbounds function setindex!(M::MatrixElem, x, i::Integer)
if nrows(M) == 1
M[1, i] = x
return M
elseif ncols(M) == 1
M[i, 1] = x
return M
else
throw(ArgumentError("linear indexing not supported for non-vector matrices"))
end
end

# iteration

function Base.iterate(a::MatrixElem{T}, ij=(0, 1)) where T <: NCRingElement
Expand Down
17 changes: 0 additions & 17 deletions src/generic/Matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,23 +100,6 @@ function deepcopy_internal(d::MatSpaceView{T}, dict::IdDict) where T <: NCRingEl
return MatSpaceView(deepcopy_internal(d.entries, dict), d.base_ring)
end

###############################################################################
#
# getindex
#
###############################################################################

# linear indexing for row- or column- vectors
Base.@propagate_inbounds function getindex(M::MatElem, x::Integer)
if nrows(M) == 1
M[1, x]
elseif ncols(M) == 1
M[x, 1]
else
throw(ArgumentError("linear indexing not supported for non-vector matrices"))
end
end

function Base.view(M::Mat{T}, rows::AbstractUnitRange{Int}, cols::AbstractUnitRange{Int}) where T <: NCRingElement
return MatSpaceView(view(M.entries, rows, cols), M.base_ring)
end
Expand Down
14 changes: 14 additions & 0 deletions test/generic/Matrix-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,20 @@ end
elseif length(A) >= 1
@test_throws ArgumentError A[1]
end
A = deepcopy(A)
if nr == 1
c = rand(1:nc)
d = rand(1:nc)
A[c] = A[d]
@test A[c] == A[1, d]
elseif nc == 1
r = rand(1:nr)
s = rand(1:nr)
A[r] = A[s]
@test A[r] == A[s, 1]
elseif length(A) >= 1
@test_throws ArgumentError (A[1] = zero(base_ring(A)))
end
end

for (_, (R, rand_params)) in RINGS
Expand Down

0 comments on commit 060bf95

Please sign in to comment.