Skip to content

Commit

Permalink
Update interface to allow input to be State
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglw0521 committed Oct 20, 2023
1 parent 7f032f4 commit ca3e1a2
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 42 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.0.2"
ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DecoratedParticles = "023d0394-cb16-4d2d-a5c7-724bed42bbb6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Expand Down
43 changes: 15 additions & 28 deletions src/builder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ d: Input dimension
categories : A list of categories
"""

function xx2AA(spec_nlm, radial::Radial_basis; categories=[], d=3, rSH = false) # Configuration to AA bases - this is what all chains have in common
function xx2AA(spec_nlm, radial::Radial_basis; categories=[], _get_cat = _get_cat_default, d=3, rSH = false) # Configuration to AA bases - this is what all chains have in common
# from spec_nlm to all possible spec1p
spec1p, lmax, nmax = specnlm2spec1p(spec_nlm)
# An assertation whether all the radial specs are in spec1p
Expand All @@ -146,6 +146,7 @@ function xx2AA(spec_nlm, radial::Radial_basis; categories=[], d=3, rSH = false)
# Define categorical bases
δs = CategoricalBasis(categories)
l_δs = P4ML.lux(δs)
l_δs = append_layer(Chain(get_cat = WrappedFunction(_get_cat), ), l_δs; l_name = :categorical)
end

spec1pidx = isempty(categories) ? getspec1idx(spec1p, radial.Radialspec, Ylm) : getspec1idx(spec1p, radial.Radialspec, Ylm, δs)
Expand All @@ -158,31 +159,17 @@ function xx2AA(spec_nlm, radial::Radial_basis; categories=[], d=3, rSH = false)
# wrapping into lux layers
l_Rnl = radial.Rnl
l_Ylm = P4ML.lux(Ylm)
l_Ylm = append_layer(Chain(get_pos = WrappedFunction(x -> [ x[i].rr for i = 1:length(x)]), ), l_Ylm; l_name = :angle_poly)
l_bA = P4ML.lux(bA)
l_bAA = P4ML.lux(bAA)

Spec_after = Polynomials4ML.reconstruct_spec(l_bAA.basis)
@assert Spec == Spec_after

dict = Dict([Spec_after[i] => i for i = 1 : length(Spec_after)])
pos = [ dict[sort(Spec[i])] for i = 1:length(Spec) ]

# formming model with Lux Chain
_norm(x) = norm.(x)

if isempty(categories)
l_embed = Lux.Parallel(nothing; Rn = l_Rnl, Ylm = l_Ylm)
luxchain = Chain(embed = l_embed, A = l_bA , AA = l_bAA)
else
l_Rnl = append_layer(Chain(get_pos = get_i(1), ), l_Rnl; l_name = :radial_poly)
l_Ylm = append_layer(Chain(get_pos = get_i(1), ), l_Ylm; l_name = :angle_poly)
l_δs = append_layer(Chain(get_cat = get_i(2), ), l_δs; l_name = :categorical)

l_embed = Lux.Parallel(nothing; Rn = l_Rnl, Ylm = l_Ylm, δs = l_δs)
luxchain = Chain(embed = l_embed, A = l_bA , AA = l_bAA) # Chain(l_xnxz = l_xnxz, embed = l_embed, A = l_bA , AA = l_bAA)
end

# luxchain = Chain(l_xnxz = l_xnxz, embed = l_embed, A = l_bA , AA = l_bAA)
# forming model with Lux Chain
l_embed = isempty(categories) ? Lux.Parallel(nothing; Rn = l_Rnl, Ylm = l_Ylm) : Lux.Parallel(nothing; Rn = l_Rnl, Ylm = l_Ylm, δs = l_δs)
luxchain = Chain(embed = l_embed, A = l_bA , AA = l_bAA)

ps, st = Lux.setup(MersenneTwister(1234), luxchain)

return luxchain, ps, st
Expand All @@ -196,7 +183,7 @@ L : Largest equivariance level
categories : A list of categories
radial_basis : specified radial basis, default using P4ML.legendre_basis
"""
function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false)
function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], _get_cat = _get_cat_default, d=3, group="O3", islong=true, rSH = false)
if rSH && L > 0
error("rSH is only implemented (for now) for L = 0")
end
Expand All @@ -208,7 +195,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=
# sort!(spec_nlm, by = x -> length(x))
spec_nlm = closure(spec_nlm,filter_init; categories = categories)

luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm, radial; categories = categories, d = d, rSH = rSH)
luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm, radial; categories = categories, _get_cat = _get_cat, d = d, rSH = rSH)
F(X) = luxchain_tmp(X, ps_tmp, st_tmp)[1]

if islong
Expand Down Expand Up @@ -241,11 +228,11 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=
end

# more constructors equivariant_model
equivariant_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false) =
equivariant_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], _get_cat = _get_cat_default, d=3, group="O3", islong=true, rSH = false) =
equivariant_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax=L, islong = islong)[2], radial, L; categories, d, group, islong, rSH)

# With the _close function, the input could simply be an nnlllist (nlist,llist)
equivariant_model(nn::Vector{Int64}, ll::Vector{Int64}, radial::Radial_basis, L::Int64; categories=[], d=3, group = "O3", islong = true, rSH = false) = begin
equivariant_model(nn::Vector{Int64}, ll::Vector{Int64}, radial::Radial_basis, L::Int64; categories=[], _get_cat = _get_cat_default, d=3, group = "O3", islong = true, rSH = false) = begin
filter = islong ? RPE_filter_long(L) : RPE_filter(L)
equivariant_model(_close(nn, ll; filter = filter), radial, L; categories, d, group, islong, rSH)
end
Expand All @@ -259,14 +246,14 @@ end
# What can be adjusted in its input are: (1) total polynomial degree; (2) correlation order; (3) largest L
# (4) weight of the order of spherical harmonics; (5) specified radial basis

function equivariant_SYY_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3")
function equivariant_SYY_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], _get_cat = _get_cat_default, d=3, group="O3")
filter_init = RPE_filter_long(L)
spec_nlm = spec_nlm[findall(x -> filter_init(x) == 1, spec_nlm)]

# sort!(spec_nlm, by = x -> length(x))
spec_nlm = closure(spec_nlm, filter_init; categories = categories)

luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm, radial; categories = categories, d = d)
luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm, radial; categories = categories, _get_cat = _get_cat, d = d)
F(X) = luxchain_tmp(X, ps_tmp, st_tmp)[1]

cgen = Rot3DCoeffs_long(L) # TODO: this should be made group related
Expand All @@ -282,10 +269,10 @@ function equivariant_SYY_model(spec_nlm, radial::Radial_basis, L::Int64; categor
return luxchain, ps, st
end

equivariant_SYY_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3,group = "O3") =
equivariant_SYY_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], _get_cat = _get_cat_default, d=3,group = "O3") =
equivariant_SYY_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax = L, islong=true)[2], radial, L; categories, d, group)

equivariant_SYY_model(nn::Vector{Int64}, ll::Vector{Int64}, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3") =
equivariant_SYY_model(nn::Vector{Int64}, ll::Vector{Int64}, radial::Radial_basis, L::Int64; categories=[], _get_cat = _get_cat_default, d=3, group="O3") =
equivariant_SYY_model(_close(nn, ll; filter = RPE_filter_long(L)), radial, L; categories, d, group)

## TODO: The following should eventually go into ACEhamiltonians.jl rather than this package
Expand Down
3 changes: 2 additions & 1 deletion src/radial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@ function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_tr

f(r) = f_trans(r) * f_cut(r)

return Radial_basis(Chain(getnorm = WrappedFunction(x -> norm.(x)), trans = WrappedFunction(x -> f.(x)), poly = lux(basis), ), spec)
_norm(x) = try norm(x); catch; norm(x.rr); end
return Radial_basis(Chain(getnorm = WrappedFunction(x -> _norm.(x)), trans = WrappedFunction(x -> f.(x)), poly = lux(basis), ), spec)
end
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,4 +252,4 @@ function degord2spec(radial::Radial_basis; totaldegree, order, Lmax, catagories
return Aspec, AAspec # Aspecgetspecnlm(spec1p, spec)
end

get_i(i) = WrappedFunction(t -> t[i])
_get_cat_default(x) = [ (x[i].Zi,x[i].Zj) for i = 1:length(x) ]
7 changes: 4 additions & 3 deletions test/test_equiv_with_cate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Polynomials4ML, StaticArrays, EquivariantModels, Test, Rotations, LinearAl
using ACEbase.Testing: print_tf
using EquivariantModels: getspec1idx, _invmap, dropnames, SList, val2i, xx2AA, degord2spec, simple_radial_basis
using Polynomials4ML: lux
using DecoratedParticles

include("wigner.jl")

Expand Down Expand Up @@ -35,13 +36,13 @@ Species = [ (species[1], species[i]) for i = 1:10 ]
for ntest = 1:10
local X, θ1, θ2, θ3, Q, QX
X = [ @SVector(rand(3)) for i in 1:10 ]
XX = [X, Species]
XX = [State(rr = X[i], Zi = Species[i][1], Zj = Species[i][2]) for i = 1:length(X)]
θ1 = rand() * 2pi
θ2 = rand() * 2pi
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
QX = [SVector{3}(x) for x in Ref(Q) .* X]
QXX = [QX, Species]
QXX = [State(rr = Q * X[i], Zi = Species[i][1], Zj = Species[i][2]) for i = 1:length(X)]
# QXX = [QX, Species]

print_tf(@test F(XX)[1] F(QXX)[1])

Expand Down
27 changes: 21 additions & 6 deletions test/test_equivariance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using ACEbase.Testing: print_tf
using Rotations, WignerD, BlockDiagonals
using EquivariantModels: Radial_basis
using Polynomials4ML:lux
using DecoratedParticles

include("wigner.jl")

Expand All @@ -24,13 +25,15 @@ for L = 0:Lmax
@info("Tesing L = $L O(3) equivariance")
for _ = 1:30
local X, θ1, θ2, θ3, Q, QX
X = [ @SVector(rand(3)) for i in 1:10 ]
X = [ @SVector(rand(3)) for i in 1:10 ]
θ1 = rand() * 2pi
θ2 = rand() * 2pi
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
# Q = rand_rot()
QX = [SVector{3}(x) for x in Ref(Q) .* X]
# QX = [SVector{3}(x) for x in Ref(Q) .* X]
QX = [ State(rr = Q * X[i]) for i in 1:length(X) ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
D = wigner_D(L,Matrix(Q))'
# D = wignerD(L, θ, θ, θ)
if L == 0
Expand All @@ -45,6 +48,7 @@ for L = 0:Lmax
for _ = 1:30
local X
X = [ @SVector(rand(3)) for i in 1:10 ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
print_tf(@test F(X) F2(X))
end
println()
Expand All @@ -69,7 +73,9 @@ for ntest = 1:10
θ2 = rand() * 2pi
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
QX = [SVector{3}(x) for x in Ref(Q) .* X]
# QX = [SVector{3}(x) for x in Ref(Q) .* X]
QX = [ State(rr = Q * X[i]) for i in 1:length(X) ]
X = [ State(rr = X[i]) for i in 1:length(X) ]

print_tf(@test F(X)[1] F(QX)[1])

Expand Down Expand Up @@ -98,6 +104,7 @@ for l = 0:Lmax

for ntest = 1:20
X = [ @SVector(rand(3)) for i in 1:10 ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
print_tf(@test F(X)[l+1] == FF(X))
end
println()
Expand All @@ -107,6 +114,7 @@ end
for _ = 1:10
local X
X = [ @SVector(rand(3)) for i in 1:10 ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
print_tf(@test length(F(X)) == length(F2(X)) && all([F(X)[i] F2(X)[i] for i = 1:length(F(X))]))
end
println()
Expand All @@ -131,7 +139,9 @@ for L = 0:Lmax
θ = rand() * 2pi
Q = RotXYZ(0, 0, θ)
# Q = rand_rot()
QX = [SVector{3}(x) for x in Ref(Q) .* X]
# QX = [SVector{3}(x) for x in Ref(Q) .* X]
QX = [ State(rr = Q * X[i]) for i in 1:length(X) ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
D = wignerD(L, 0, 0, θ)
if length(F(X)) == 0
continue
Expand Down Expand Up @@ -166,7 +176,9 @@ for ntest = 1:20
θ2 = rand() * 2pi
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
QX = [SVector{3}(x) for x in Ref(Q) .* X]
# QX = [SVector{3}(x) for x in Ref(Q) .* X]
QX = [ State(rr = Q * X[i]) for i in 1:length(X) ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
D = BlockDiagonal([ wigner_D(l,Matrix(Q))' for l = 0:L] )

print_tf(@test Ref(D) .* F(QX) F(X))
Expand All @@ -177,6 +189,7 @@ println()
for _ = 1:10
local X
X = [ @SVector(rand(3)) for i in 1:10 ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
print_tf(@test length(F(X)) == length(F2(X)) && all([F(X)[i] F2(X)[i] for i = 1:length(F(X))]))
end
println()
Expand All @@ -201,7 +214,9 @@ for ntest = 1:20
θ2 = rand() * 2pi
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
QX = [SVector{3}(x) for x in Ref(Q) .* X]
# QX = [SVector{3}(x) for x in Ref(Q) .* X]
QX = [ State(rr = Q * X[i]) for i in 1:length(X) ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
D = BlockDiagonal([ wigner_D(l,Matrix(Q))' for l = 0:L] )

print_tf(@test Ref(D) .* F(QX) F(X))
Expand Down
15 changes: 12 additions & 3 deletions test/test_rSH_equivariance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ for L = 0:Lmax
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
# Q = rand_rot()
QX = [SVector{3}(x) for x in Ref(Q) .* X]
# QX = [SVector{3}(x) for x in Ref(Q) .* X]
QX = [ State(rr = Q * X[i]) for i in 1:length(X) ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
D = wigner_D(L,Matrix(Q))'
# D = wignerD(L, θ, θ, θ)

Expand All @@ -46,6 +48,7 @@ for L = 0:Lmax
for _ = 1:30
local X
X = [ @SVector(rand(3)) for i in 1:10 ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
print_tf(@test F(X) F2(X))
end
println()
Expand All @@ -70,7 +73,9 @@ for ntest = 1:10
θ2 = rand() * 2pi
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
QX = [SVector{3}(x) for x in Ref(Q) .* X]
# QX = [SVector{3}(x) for x in Ref(Q) .* X]
QX = [ State(rr = Q * X[i]) for i in 1:length(X) ]
X = [ State(rr = X[i]) for i in 1:length(X) ]

print_tf(@test F(X)[1] F(QX)[1])
end
Expand All @@ -93,6 +98,7 @@ for l = 0:Lmax

for ntest = 1:20
X = [ @SVector(rand(3)) for i in 1:10 ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
print_tf(@test F(X)[l+1] == FF(X))
end
println()
Expand All @@ -102,6 +108,7 @@ end
for _ = 1:10
local X
X = [ @SVector(rand(3)) for i in 1:10 ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
print_tf(@test length(F(X)) == length(F2(X)) && all([F(X)[i] F2(X)[i] for i = 1:length(F(X))]))
end
println()
Expand All @@ -126,7 +133,9 @@ for L = 0:Lmax
θ = rand() * 2pi
Q = RotXYZ(0, 0, θ)
# Q = rand_rot()
QX = [SVector{3}(x) for x in Ref(Q) .* X]
# QX = [SVector{3}(x) for x in Ref(Q) .* X]
QX = [ State(rr = Q * X[i]) for i in 1:length(X) ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
D = wignerD(L, 0, 0, θ)
if length(F(X)) == 0
continue
Expand Down

0 comments on commit ca3e1a2

Please sign in to comment.