Skip to content

Commit

Permalink
get rid of the "position" projectin
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglw0521 committed Oct 5, 2023
1 parent ae937d1 commit 149447d
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/builder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=
end

#TODO:make use [ C[i], pos[i] ] to generate another sparse matrix so that...
l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(C[i],pos[i]) for i = 1:L+1]... ) : ConstLinearLayer(C)
l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i])) for i = 1:L+1]... ) : ConstLinearLayer(C)
# C - A2Bmap
luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB)
# luxchain = Chain(xx2AA = luxchain_tmp, BB = l_sym)
Expand All @@ -272,6 +272,15 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=
return luxchain, ps, st
end

function new_sparse_matrix(C,pos)
col = maximum(pos)
C_new = sparse(zeros(typeof(C[1]),size(C,1),col))
for i = 1:size(C,1)
C_new[i,pos] = C[i,:]
end
return C_new
end

# more constructors equivariant_model
equivariant_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false) =
equivariant_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax=L, islong = islong)[2], radial, L; categories, d, group, islong, rSH)
Expand Down

0 comments on commit 149447d

Please sign in to comment.