Skip to content

Commit

Permalink
3-vector KronTrav
Browse files Browse the repository at this point in the history
  • Loading branch information
dlfivefifty committed Sep 27, 2024
1 parent e64552b commit 736f88c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
15 changes: 12 additions & 3 deletions src/blockkron.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ KronTrav(A::AbstractArray...) = KronTrav{mapreduce(eltype, promote_type, A)}(A..
copy(K::KronTrav) = KronTrav(map(copy,K.args), K.axes)
axes(A::KronTrav) = A.axes



function _krontrav_getindex(K::Block{1}, A, B)
m,n = length(A), length(B)
mn = min(m,n)
Expand All @@ -199,11 +201,18 @@ function _krontrav_getindex(K::Block{1}, A, B)
end
end



function _krontrav_getindex(K::Block{1}, A, B, C)
@assert length(A) == length(B) == length(C) # TODO: generalise
n = length(A)
k = Int(K)
vcat(((A[1:(k-j+1)] .* B[(k-j+1):-1:1]) * C[j] for j=1:k)...)

# make a tuple corresponding to lexigraphical order
ret = Vector{promote_type(eltype(A),eltype(B),eltype(C))}()
n = Int(K)
for k = 1:n, j=1:k
push!(ret, C[n-k+1]B[k-j+1]A[j])
end
ret

Check warning on line 215 in src/blockkron.jl

View check run for this annotation

Codecov / codecov/patch

src/blockkron.jl#L215

Added line #L215 was not covered by tests
end

function _krontrav_getindex(K::Block{2}, A, B)
Expand Down
12 changes: 4 additions & 8 deletions test/test_blockkron.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,25 +99,21 @@ LinearAlgebra.factorize(A::MyLazyArray) = factorize(A.data)
a = [1,2,3]
b = [4,5,6]
c = [7,8]
@test KronTrav(a,b) == DiagTrav(b*a')
@test KronTrav(a,b) == DiagTrav(b*a') == DiagTrav(kron(a',b))
@test KronTrav(a,c) == [7,8,14,16,21]
@test KronTrav(c,a) == [7,14,8,21,16]

X = rotl90(Matrix(UpperTriangular(randn(3,3)))) # triangle of coefficients
@test KronTrav(a,b)' * DiagTrav(X) b'*X*a sum(a' .* X .* b)
@test KronTrav(a,b)' * DiagTrav(X) b'*X*a sum(b .* X .* a')
end

@testset "3-vectors" begin
a = [1,2,3]
b = [4,5,6]
c = [7,8,9]


# X = [k + j + l - 2 ≤ 3 ? randn() : 0.0 for k=1:3,j=1:3,l=1:3]
X = zeros(3,3,3)
X[2,1,2] = 1;
KronTrav(a,b,c)' * DiagTrav(X)
sum(a' .* X .* b .* reshape(c,1,1,3))
X = [k + j + l - 2 3 ? randn() : 0.0 for k=1:3,j=1:3,l=1:3]
@test KronTrav(a,b,c)' * DiagTrav(X) sum(c .* X .* b' .* reshape(a,1,1,3))
end

@testset "matrix" begin
Expand Down

0 comments on commit 736f88c

Please sign in to comment.