Skip to content

Commit

Permalink
Correctly set zeros with fill!(::SubArray) and fix its return value (
Browse files Browse the repository at this point in the history
  • Loading branch information
sunoru authored Sep 5, 2023
1 parent 03ed9e3 commit 2fae1a1
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
5 changes: 3 additions & 2 deletions src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3179,7 +3179,8 @@ function Base.fill!(V::SubArray{Tv, <:Any, <:AbstractSparseMatrixCSC{Tv}, <:Tupl
else
_spsetnz_setindex!(A, convert(Tv, x), I, J)
end
return _checkbuffers(A)
_checkbuffers(A)
V
end
"""
Helper method for immediately preceding fill! method. For all (i,j) such that i in I and
Expand Down Expand Up @@ -3207,7 +3208,7 @@ function _spsetz_setindex!(A::AbstractSparseMatrixCSC,
kI > lengthI && break
entrykIrow = I[kI]
else # entrykArow == entrykIrow
nonzeros(A)[kA] = 0
nonzeros(A)[kA] = zero(eltype(A))
kA += 1
kI += 1
(kA > coljAlastk || kI > lengthI) && break
Expand Down
17 changes: 15 additions & 2 deletions test/sparsematrix_constructors_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ end
A = rand(5,5)
= similar(A)
Ac = copyto!(A, B)
@test Ac === A
@test Ac === A
@test A == copyto!(A´, Matrix(B))
# Test copyto!(dense, Rdest, sparse, Rsrc)
A = rand(5,5)
Expand Down Expand Up @@ -1363,7 +1363,9 @@ end
a = sprand(10, 10, 0.2)
b = copy(a)
sa = view(a, 1:10, 2:3)
fill!(sa, 0.0)
sa_filled = fill!(sa, 0.0)
# `fill!` should return the sub array instead of its parent.
@test sa_filled === sa
b[1:10, 2:3] .= 0.0
@test a == b
A = sparse([1], [1], [Vector{Float64}(undef, 3)], 3, 3)
Expand All @@ -1375,6 +1377,17 @@ end
B[1, jj] = [4.0, 5.0, 6.0]
end
@test A == B

# https://github.com/JuliaSparse/SparseArrays.jl/pull/433
struct Foo
x::Int
end
Base.zero(::Type{Foo}) = Foo(0)
Base.zero(::Foo) = zero(Foo)
C = sparse([1], [1], [Foo(3)], 3, 3)
sC = view(C, 1:1, 1:2)
fill!(sC, zero(Foo))
@test C[1:1, 1:2] == zeros(Foo, 1, 2)
end

using Base: swaprows!, swapcols!
Expand Down

0 comments on commit 2fae1a1

Please sign in to comment.