Skip to content
This repository has been archived by the owner on Dec 7, 2021. It is now read-only.

Commit

Permalink
Fix BlockSparse-DiagBlockSparse contraction bug (#64)
Browse files Browse the repository at this point in the history
* Fix bug in BlockSparseTensor-DiagBlockSparseTensor contractions

* Bump to NDTensors v0.1.24
  • Loading branch information
mtfishman authored Feb 2, 2021
1 parent f5d9733 commit 6d911a6
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 41 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.1.23"
version = "0.1.24"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
63 changes: 31 additions & 32 deletions src/blocksparse/diagblocksparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ is known already).
"""
blockview(T::DiagBlockSparseTensor, blockT::Block) = blockview(T, BlockOffset(blockT, offset(T, blockT)))

getindex(T::DiagBlockSparseTensor, block::Block) = blockview(T, block)

function blockview(T::DiagBlockSparseTensor,
bof::BlockOffset)
blockT,offsetT = bof
Expand Down Expand Up @@ -522,49 +524,46 @@ function contraction_output(T1::TensorT1,
return R,contraction_plan
end

function contract(T1::BlockSparseTensor,
labelsT1,
T2::DiagBlockSparseTensor,
labelsT2,
function contract(T1::BlockSparseTensor, labelsT1,
T2::DiagBlockSparseTensor, labelsT2,
labelsR = contract_labels(labelsT1,labelsT2))
R,contraction_plan = contraction_output(T1,labelsT1,T2,labelsT2,labelsR)
R = contract!(R,labelsR,T1,labelsT1,T2,labelsT2,contraction_plan)
return R
end

contract(T1::DiagBlockSparseTensor,
labelsT1,
T2::BlockSparseTensor,
labelsT2,
labelsR = contract_labels(labelsT2,labelsT1)) = contract(T2,labelsT2,T1,labelsT1,labelsR)

function contract!(R::BlockSparseTensor,
labelsR,
T1::BlockSparseTensor,
labelsT1,
T2::DiagBlockSparseTensor,
labelsT2,
contraction_plan)
for (pos1,pos2,posR) in contraction_plan
blockT1 = blockview(T1,pos1)
blockT2 = blockview(T2,pos2)
blockR = blockview(R,posR)
contract!(blockR,labelsR,
blockT1,labelsT1,
blockT2,labelsT2)
contract(T1::DiagBlockSparseTensor, labelsT1,
T2::BlockSparseTensor, labelsT2,
labelsR = contract_labels(labelsT2,labelsT1)) =
contract(T2,labelsT2,T1,labelsT1,labelsR)

function contract!(R::BlockSparseTensor{ElR, NR}, labelsR, T1::BlockSparseTensor, labelsT1,
T2::DiagBlockSparseTensor, labelsT2, contraction_plan) where {ElR <: Number, NR}
already_written_to = Dict{Block{NR}, Bool}()
# In R .= α .* (T1 * T2) .+ β .* R
α = one(ElR)
for (block1, block2, blockR) in contraction_plan
T1block = T1[block1]
T2block = T2[block2]
Rblock = R[blockR]
β = one(ElR)
if !haskey(already_written_to, blockR)
already_written_to[blockR] = true
# Overwrite the block of R
β = zero(ElR)
end
contract!(Rblock, labelsR, T1block, labelsT1,
T2block, labelsT2, α, β)
end
return R
end

contract!(C::BlockSparseTensor,Clabels,
A::BlockSparseTensor,Alabels,
B::DiagBlockSparseTensor,Blabels) = contract!(C,Clabels,
B,Blabels,
A,Alabels)
contract!(C::BlockSparseTensor, Clabels,
A::BlockSparseTensor, Alabels,
B::DiagBlockSparseTensor, Blabels) =
contract!(C, Clabels, B,Blabels, A, Alabels)

function show(io::IO,
mime::MIME"text/plain",
T::DiagBlockSparseTensor)
function show(io::IO, mime::MIME"text/plain", T::DiagBlockSparseTensor)
summary(io,T)
for (n, (block, _)) in enumerate(diagblockoffsets(T))
blockdimsT = blockdims(T,block)
Expand Down
27 changes: 19 additions & 8 deletions src/diag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,8 @@ end

function contract!(C::DenseTensor{ElC,NC},Clabels,
A::DiagTensor{ElA,NA},Alabels,
B::DenseTensor{ElB,NB},Blabels;
B::DenseTensor{ElB,NB},Blabels,
α::Number = one(ElC), β::Number = zero(ElC);
convert_to_dense::Bool = true) where {ElA,NA,
ElB,NB,
ElC,NC}
Expand All @@ -411,21 +412,28 @@ function contract!(C::DenseTensor{ElC,NC},Clabels,
if length(Clabels) == 0
# all indices are summed over, just add the product of the diagonal
# elements of A and B
# Assumes C starts set to 0
c₁ = zero(ElC)
for i = 1:min_dim
setdiagindex!(C,getdiagindex(C,1)+getdiagindex(A,i)*getdiagindex(B,i),1)
c₁ += getdiagindex(A, i) * getdiagindex(B, i)
end
setdiagindex!(C, α * c₁ + β * getdiagindex(C, 1), 1)
else
# not all indices are summed over, set the diagonals of the result
# to the product of the diagonals of A and B
# TODO: should we make this return a Diag storage?
for i = 1:min_dim
setdiagindex!(C,getdiagindex(A,i)*getdiagindex(B,i),i)
setdiagindex!(C, α * getdiagindex(A, i) * getdiagindex(B, i) + β * getdiagindex(C, i), i)
end
end
else
# Most general contraction
if convert_to_dense
contract!(C, Clabels, dense(A), Alabels, B, Blabels)
contract!(C, Clabels, dense(A), Alabels, B, Blabels, α, β)
else
if !isone(α) || !iszero(β)
error("`contract!(::DenseTensor, ::DiagTensor, ::DenseTensor, α, β; convert_to_dense = false)` with `α ≠ 1` or `β ≠ 0` is not currently supported. You can call it with `convert_to_dense = true` instead.")
end
astarts = zeros(Int,length(Alabels))
bstart = 0
cstart = 0
Expand Down Expand Up @@ -479,9 +487,12 @@ function contract!(C::DenseTensor{ElC,NC},Clabels,
boffset += ii*bustride[i]
coffset += ii*custride[i]
end
c = zero(ElC)
for j in 1:diaglength(A)
C[cstart+j*c_cstride+coffset] += getdiagindex(A,j)*
B[bstart+j*b_cstride+boffset]
# With α == 0 && β == 1
C[cstart+j*c_cstride+coffset] += getdiagindex(A, j)* B[bstart+j*b_cstride+boffset]
# XXX: not sure if this is correct
#C[cstart+j*c_cstride+coffset] += α * getdiagindex(A, j)* B[bstart+j*b_cstride+boffset] + β * C[cstart+j*c_cstride+coffset]
end
end
end
Expand All @@ -490,8 +501,8 @@ function contract!(C::DenseTensor{ElC,NC},Clabels,
end

contract!(C::DenseTensor, Clabels, A::DenseTensor, Alabels,
B::DiagTensor, Blabels) =
contract!(C, Clabels, B, Blabels, A, Alabels)
B::DiagTensor, Blabels, α::Number = one(eltype(C)), β::Number = zero(eltype(C))) =
contract!(C, Clabels, B, Blabels, A, Alabels, α, β)

function show(io::IO, mime::MIME"text/plain", T::DiagTensor)
summary(io,T)
Expand Down

2 comments on commit 6d911a6

@mtfishman
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/29210

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.24 -m "<description of version>" 6d911a62bc4474b80f02d9f0371e60ed775634e9
git push origin v0.1.24

Please sign in to comment.