Skip to content

Commit

Permalink
Minor revision
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglw0521 committed Oct 5, 2023
1 parent 42ad3a9 commit 8e63f32
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 14 deletions.
8 changes: 4 additions & 4 deletions src/ConstLinearLayer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ import ChainRulesCore: rrule
using LuxCore
using LuxCore: AbstractExplicitLayer

struct ConstLinearLayer <: AbstractExplicitLayer
op
struct ConstLinearLayer{T} <: AbstractExplicitLayer
op::T
end

(l::ConstLinearLayer)(x::AbstractVector) = l.op * x[1:size(l.op,2)]
(l::ConstLinearLayer{T})(x::AbstractVector) where T = l.op * x

(l::ConstLinearLayer)(x::AbstractMatrix) = begin
(l::ConstLinearLayer{T})(x::AbstractMatrix) where T = begin
Tmp = l(x[1,:])
for i = 2:size(x,1)
Tmp = [Tmp l(x[i,:])]
Expand Down
11 changes: 1 addition & 10 deletions src/builder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=
C = _rpi_A2B_matrix(cgen, spec_nlm)
end

l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i])) for i = 1:L+1]... ) : ConstLinearLayer(C)
l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C)
# C - A2Bmap
luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB)
# luxchain = Chain(xx2AA = luxchain_tmp, BB = l_sym)
Expand All @@ -271,15 +271,6 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=
return luxchain, ps, st
end

function new_sparse_matrix(C,pos)
col = maximum(pos)
C_new = sparse(zeros(typeof(C[1]),size(C,1),col))
for i = 1:size(C,1)
C_new[i,pos] = C[i,:]
end
return sparse(C_new)
end

# more constructors equivariant_model
equivariant_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false) =
equivariant_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax=L, islong = islong)[2], radial, L; categories, d, group, islong, rSH)
Expand Down
8 changes: 8 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,11 @@ function degord2spec(radial::Radial_basis; totaldegree, order, Lmax, catagories
end

get_i(i) = WrappedFunction(t -> t[i])

function new_sparse_matrix(C,pos,len)
C_new = sparse(zeros(typeof(C[1]),size(C,1),len))
for i = 1:size(C,1)
C_new[i,pos] = C[i,:]
end
return sparse(C_new)
end

0 comments on commit 8e63f32

Please sign in to comment.