From 8e63f32c3acb4426c76e5b0d07fd097b4ef9041d Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Thu, 5 Oct 2023 16:03:20 -0700 Subject: [PATCH] Minor revision --- src/ConstLinearLayer.jl | 8 ++++---- src/builder.jl | 11 +---------- src/utils.jl | 8 ++++++++ 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/ConstLinearLayer.jl b/src/ConstLinearLayer.jl index d03ccd6..d2a9f7f 100644 --- a/src/ConstLinearLayer.jl +++ b/src/ConstLinearLayer.jl @@ -2,13 +2,13 @@ import ChainRulesCore: rrule using LuxCore using LuxCore: AbstractExplicitLayer -struct ConstLinearLayer <: AbstractExplicitLayer - op +struct ConstLinearLayer{T} <: AbstractExplicitLayer + op::T end -(l::ConstLinearLayer)(x::AbstractVector) = l.op * x[1:size(l.op,2)] +(l::ConstLinearLayer{T})(x::AbstractVector) where T = l.op * x -(l::ConstLinearLayer)(x::AbstractMatrix) = begin +(l::ConstLinearLayer{T})(x::AbstractMatrix) where T = begin Tmp = l(x[1,:]) for i = 2:size(x,1) Tmp = [Tmp l(x[i,:])] diff --git a/src/builder.jl b/src/builder.jl index c215411..e615144 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -261,7 +261,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= C = _rpi_A2B_matrix(cgen, spec_nlm) end - l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i])) for i = 1:L+1]... ) : ConstLinearLayer(C) + l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C) # C - A2Bmap luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB) # luxchain = Chain(xx2AA = luxchain_tmp, BB = l_sym) @@ -271,15 +271,6 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= return luxchain, ps, st end -function new_sparse_matrix(C,pos) - col = maximum(pos) - C_new = sparse(zeros(typeof(C[1]),size(C,1),col)) - for i = 1:size(C,1) - C_new[i,pos] = C[i,:] - end - return sparse(C_new) -end - # more constructors equivariant_model equivariant_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false) = equivariant_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax=L, islong = islong)[2], radial, L; categories, d, group, islong, rSH) diff --git a/src/utils.jl b/src/utils.jl index d319690..68cf47c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -253,3 +253,11 @@ function degord2spec(radial::Radial_basis; totaldegree, order, Lmax, catagories end get_i(i) = WrappedFunction(t -> t[i]) + +function new_sparse_matrix(C,pos,len) + C_new = sparse(zeros(typeof(C[1]),size(C,1),len)) + for i = 1:size(C,1) + C_new[i,pos] = C[i,:] + end + return sparse(C_new) +end \ No newline at end of file