Skip to content

Commit

Permalink
Work on 3-argument KronTrav (#134)
Browse files Browse the repository at this point in the history
* Work on 3-argument KronTrav

* Update blockkron.jl

* 3-vector KronTrav
  • Loading branch information
dlfivefifty authored Oct 8, 2024
1 parent bf02ac0 commit b71450c
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
29 changes: 23 additions & 6 deletions src/blockkron.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ function _diagtravgetindex(::AbstractPaddedLayout{<:AbstractStridedLayout}, A::A
end
end


function _diagtravgetindex(::AbstractStridedLayout, A::AbstractArray{T,3}, K::Block{1}) where T
k = Int(K)
m,n,p = size(A)
Expand All @@ -103,7 +104,7 @@ function _diagtravgetindex(::AbstractStridedLayout, A::AbstractArray{T,3}, K::Bl
st3 = stride(A,3)
ret = T[]
for j = 0:k-1
append!(ret, view(A, range(j*st + k-j; step=st3-st, length=j+1)))
append!(ret, view(A, range(j*st + k-j; step=st3-st, length=j+1))) # this matches lexigraphical order
end
ret
end
Expand Down Expand Up @@ -152,7 +153,7 @@ end
size(A::InvDiagTrav) = (blocksize(A.vector,1),blocksize(A.vector,1))

function getindex(A::InvDiagTrav{T}, k::Int, j::Int) where T
if k+j-1  blocksize(A.vector,1)
if k+j-1 blocksize(A.vector,1)
A.vector[Block(k+j-1)][j]
else
zero(T)
Expand Down Expand Up @@ -185,20 +186,35 @@ KronTrav(A::AbstractArray...) = KronTrav{mapreduce(eltype, promote_type, A)}(A..
copy(K::KronTrav) = KronTrav(map(copy,K.args), K.axes)
axes(A::KronTrav) = A.axes

function getindex(M::KronTrav{<:Any,1}, K::Block{1})
A,B = M.args


function _krontrav_getindex(K::Block{1}, A, B)
m,n = length(A), length(B)
mn = min(m,n)
k = Int(K)
if k  mn
if k mn
A[1:k] .* B[k:-1:1]
elseif m < n
elseif m < n
A .* B[k:-1:(k-m+1)]
else # n < m
A[(k-n+1):k] .* B[end:-1:1]
end
end



function _krontrav_getindex(K::Block{1}, A, B, C)
@assert length(A) == length(B) == length(C) # TODO: generalise

# make a tuple corresponding to lexigraphical order
ret = Vector{promote_type(eltype(A),eltype(B),eltype(C))}()
n = Int(K)
for k = 1:n, j=1:k
push!(ret, C[n-k+1]B[k-j+1]A[j])
end
ret
end

function _krontrav_getindex(K::Block{2}, A, B)
m,n = size(A), size(B)
@assert m == n
Expand All @@ -219,6 +235,7 @@ function _krontrav_getindex(Kin::Block{2}, A, B, C)
AB
end

getindex(M::KronTrav{<:Any,1}, K::Block{1}) = _krontrav_getindex(K, M.args...)
getindex(M::KronTrav{<:Any,2}, K::Block{2}) = _krontrav_getindex(K, M.args...)


Expand Down
14 changes: 13 additions & 1 deletion test/test_blockkron.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,21 @@ LinearAlgebra.factorize(A::MyLazyArray) = factorize(A.data)
a = [1,2,3]
b = [4,5,6]
c = [7,8]
@test KronTrav(a,b) == DiagTrav(b*a')
@test KronTrav(a,b) == DiagTrav(b*a') == DiagTrav(kron(a',b))
@test KronTrav(a,c) == [7,8,14,16,21]
@test KronTrav(c,a) == [7,14,8,21,16]

X = rotl90(Matrix(UpperTriangular(randn(3,3)))) # triangle of coefficients
@test KronTrav(a,b)' * DiagTrav(X) b'*X*a sum(b .* X .* a')
end

@testset "3-vectors" begin
a = [1,2,3]
b = [4,5,6]
c = [7,8,9]

X = [k + j + l - 2 3 ? randn() : 0.0 for k=1:3,j=1:3,l=1:3]
@test KronTrav(a,b,c)' * DiagTrav(X) sum(c .* X .* b' .* reshape(a,1,1,3))
end

@testset "matrix" begin
Expand Down

0 comments on commit b71450c

Please sign in to comment.