From 736f88c19c53ec311566bdb63751311212673201 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Fri, 27 Sep 2024 11:52:32 +0100 Subject: [PATCH] 3-vector KronTrav --- src/blockkron.jl | 15 ++++++++++++--- test/test_blockkron.jl | 12 ++++-------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/blockkron.jl b/src/blockkron.jl index 2802e43..993cc8e 100644 --- a/src/blockkron.jl +++ b/src/blockkron.jl @@ -186,6 +186,8 @@ KronTrav(A::AbstractArray...) = KronTrav{mapreduce(eltype, promote_type, A)}(A.. copy(K::KronTrav) = KronTrav(map(copy,K.args), K.axes) axes(A::KronTrav) = A.axes + + function _krontrav_getindex(K::Block{1}, A, B) m,n = length(A), length(B) mn = min(m,n) @@ -199,11 +201,18 @@ function _krontrav_getindex(K::Block{1}, A, B) end end + + function _krontrav_getindex(K::Block{1}, A, B, C) @assert length(A) == length(B) == length(C) # TODO: generalise - n = length(A) - k = Int(K) - vcat(((A[1:(k-j+1)] .* B[(k-j+1):-1:1]) * C[j] for j=1:k)...) + + # make a tuple corresponding to lexigraphical order + ret = Vector{promote_type(eltype(A),eltype(B),eltype(C))}() + n = Int(K) + for k = 1:n, j=1:k + push!(ret, C[n-k+1]B[k-j+1]A[j]) + end + ret end function _krontrav_getindex(K::Block{2}, A, B) diff --git a/test/test_blockkron.jl b/test/test_blockkron.jl index 12da44b..ce6d448 100644 --- a/test/test_blockkron.jl +++ b/test/test_blockkron.jl @@ -99,12 +99,12 @@ LinearAlgebra.factorize(A::MyLazyArray) = factorize(A.data) a = [1,2,3] b = [4,5,6] c = [7,8] - @test KronTrav(a,b) == DiagTrav(b*a') + @test KronTrav(a,b) == DiagTrav(b*a') == DiagTrav(kron(a',b)) @test KronTrav(a,c) == [7,8,14,16,21] @test KronTrav(c,a) == [7,14,8,21,16] X = rotl90(Matrix(UpperTriangular(randn(3,3)))) # triangle of coefficients - @test KronTrav(a,b)' * DiagTrav(X) ≈ b'*X*a ≈ sum(a' .* X .* b) + @test KronTrav(a,b)' * DiagTrav(X) ≈ b'*X*a ≈ sum(b .* X .* a') end @testset "3-vectors" begin @@ -112,12 +112,8 @@ LinearAlgebra.factorize(A::MyLazyArray) = factorize(A.data) b = [4,5,6] c = [7,8,9] - - # X = [k + j + l - 2 ≤ 3 ? randn() : 0.0 for k=1:3,j=1:3,l=1:3] - X = zeros(3,3,3) - X[2,1,2] = 1; - KronTrav(a,b,c)' * DiagTrav(X) - sum(a' .* X .* b .* reshape(c,1,1,3)) + X = [k + j + l - 2 ≤ 3 ? randn() : 0.0 for k=1:3,j=1:3,l=1:3] + @test KronTrav(a,b,c)' * DiagTrav(X) ≈ sum(c .* X .* b' .* reshape(a,1,1,3)) end @testset "matrix" begin