Skip to content

Commit

Permalink
Remove some Wrappings
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglw0521 committed Sep 22, 2023
1 parent 8e78b5f commit 4591b13
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
1 change: 1 addition & 0 deletions examples/potential/potential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ chain_AA2B, ps2, st2 = equivariant_model(AAspec, maxL)
X = [ @SVector(rand(3)) for i in 1:10 ]

chain_xx2AA(X, ps1, st1)

chain_AA2B(X, ps2, st2)
10 changes: 6 additions & 4 deletions src/builder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ function equivariant_model(spec_nlm, L::Int64, d=3, categories=[]; radial_basis=
l_sym = islong ? Lux.Parallel(nothing, [WrappedFunction(x -> C[i] * x[pos[i]]) for i = 1:L+1]... ) : WrappedFunction(x -> C * x)

# C - A2Bmap
luxchain = Chain(xx2AA = WrappedFunction(x -> F(x)), BB = l_sym)
# luxchain = Chain(xx2AA = WrappedFunction(x -> F(x)), BB = l_sym)
luxchain = Chain(xx2AA = luxchain_tmp, BB = l_sym)

ps, st = Lux.setup(MersenneTwister(1234), luxchain)

Expand Down Expand Up @@ -329,7 +330,7 @@ function degord2spec(;totaldegree,order,Lmax,radial_basis = legendre_basis, wL =
end

equivariant_model(totdeg::Int64, ν::Int64, L::Int64, d=3, categories=[]; radial_basis=legendre_basis, group="O3", islong=true) =
equivariant_model(degord2spec_nlm(totdeg,ν,L; radial_basis=radial_basis,islong=islong),L,d,categories;radial_basis,group,islong)
equivariant_model(degord2spec(;totaldegree=totdeg,order=ν,Lmax=L,radial_basis=radial_basis,islong=islong)[2],L,d,categories;radial_basis,group,islong)

## The following are SYYVector-related codes - which we might want to either use or get rid of someday...

Expand All @@ -353,15 +354,16 @@ function equivariant_SYY_model(spec_nlm, L::Int64, d=3, categories=[]; radial_ba
l_sym = WrappedFunction(x -> C * x)

# C - A2Bmap
luxchain = Chain(xx2AA = WrappedFunction(x -> F(x)), BB = l_sym)
# luxchain = Chain(xx2AA = WrappedFunction(x -> F(x)), BB = l_sym)
luxchain = Chain(xx2AA = luxchain_tmp, BB = l_sym)

ps, st = Lux.setup(MersenneTwister(1234), luxchain)

return luxchain, ps, st
end

equivariant_SYY_model(totdeg::Int64::Int64,L::Int64,d=3,categories=[];radial_basis = legendre_basis,group = "O3") =
equivariant_SYY_model(degord2spec_nlm(totdeg,ν,L; radial_basis=radial_basis,islong=true),L,d,categories;radial_basis,group)
equivariant_SYY_model(degord2spec(;totaldegree=totdeg,order=ν,Lmax=L,radial_basis=radial_basis,islong=true)[2],L,d,categories;radial_basis,group)

equivariant_SYY_model(nn::Vector{Int64}, ll::Vector{Int64}, L::Int64, d=3, categories=[]; radial_basis=legendre_basis, group="O3") =
equivariant_SYY_model(_close(nn,ll,RPE_filter_long(L)),L,d,categories;radial_basis,group)
Expand Down

0 comments on commit 4591b13

Please sign in to comment.