diff --git a/Project.toml b/Project.toml index 088c94e..985d165 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" diff --git a/src/ConstLinearLayer.jl b/src/ConstLinearLayer.jl index d2a9f7f..52ee758 100644 --- a/src/ConstLinearLayer.jl +++ b/src/ConstLinearLayer.jl @@ -1,5 +1,5 @@ import ChainRulesCore: rrule -using LuxCore +using LuxCore, LinearOperators using LuxCore: AbstractExplicitLayer struct ConstLinearLayer{T} <: AbstractExplicitLayer @@ -33,4 +33,21 @@ function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractArray,ps return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent() end return val, pb -end \ No newline at end of file +end + +function _linear_operator_L(L, C, pos, len) + if L == 0 + T = ComplexF64 + fL = let C=C, idx=pos#, T=T + (res, aa) -> mul!(res, C, aa[idx]);# try; mul!(res, C, aa[idx]); catch; mul!(zeros(T,size(C,1)), C, aa[idx]); end + end + else + T = SVector{2L+1,ComplexF64} + fL = let C=C, idx=pos#, T=T + (res, aa) -> begin + res[:] .= C * aa[idx] + end + end + end + return LinearOperator{T}(size(C,1), len, false, false, fL, nothing, nothing; S = Vector{T}) + end \ No newline at end of file diff --git a/src/builder.jl b/src/builder.jl index 5cab21b..c58d543 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -260,9 +260,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= cgen = rSH ? Rot3DCoeffs_real(L) : Rot3DCoeffs(L) # TODO: this should be made group related C = _rpi_A2B_matrix(cgen, spec_nlm) end - - # l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C) - l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(C[i]*sparse_trans(pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C) + l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(_linear_operator_L(i-1,C[i],pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C) # C - A2Bmap luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB) # luxchain = Chain(xx2AA = luxchain_tmp, BB = l_sym)