Skip to content

Commit

Permalink
turning C * X[pos] to LO * X with LO a combined LinearOperator
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglw0521 committed Oct 13, 2023
1 parent bdffa55 commit 7223e6c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Expand Down
21 changes: 19 additions & 2 deletions src/ConstLinearLayer.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import ChainRulesCore: rrule
using LuxCore
using LuxCore, LinearOperators
using LuxCore: AbstractExplicitLayer

struct ConstLinearLayer{T} <: AbstractExplicitLayer
Expand Down Expand Up @@ -33,4 +33,21 @@ function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractArray,ps
return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent()
end
return val, pb
end
end

function _linear_operator_L(L, C, pos, len)
if L == 0
T = ComplexF64
fL = let C=C, idx=pos#, T=T
(res, aa) -> mul!(res, C, aa[idx]);# try; mul!(res, C, aa[idx]); catch; mul!(zeros(T,size(C,1)), C, aa[idx]); end
end
else
T = SVector{2L+1,ComplexF64}
fL = let C=C, idx=pos#, T=T
(res, aa) -> begin
res[:] .= C * aa[idx]
end
end
end
return LinearOperator{T}(size(C,1), len, false, false, fL, nothing, nothing; S = Vector{T})
end
4 changes: 1 addition & 3 deletions src/builder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=
cgen = rSH ? Rot3DCoeffs_real(L) : Rot3DCoeffs(L) # TODO: this should be made group related
C = _rpi_A2B_matrix(cgen, spec_nlm)
end

# l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C)
l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(C[i]*sparse_trans(pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C)
l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(_linear_operator_L(i-1,C[i],pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C)
# C - A2Bmap
luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB)
# luxchain = Chain(xx2AA = luxchain_tmp, BB = l_sym)
Expand Down

0 comments on commit 7223e6c

Please sign in to comment.