From 8f4311c3c6644c7bad83d87278fa2210435ec0dd Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Wed, 27 Sep 2023 15:20:16 -0700 Subject: [PATCH 01/18] More flexible radial basis embedding --- src/utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils.jl b/src/utils.jl index 73ae1c3..f9c4a53 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -29,6 +29,7 @@ Return a vector of tuples of indices of spec1 w.r.t actual indices (i.e. 1, 2, 3 function getspec1idx(spec1, bRnl, bYlm) spec1idx = Vector{Tuple{Int, Int}}(undef, length(spec1)) spec_Rnl = natural_indices(bRnl); + # TODO: the following line is to be changed to be l-dependent spec_Rnl = [(n = i, ) for i in spec_Rnl] inv_Rnl = _invmap(spec_Rnl) From 5ae753fea3c881d87c790f4c3fbc6284be8d7d34 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Fri, 29 Sep 2023 00:07:43 -0700 Subject: [PATCH 02/18] First running version - a huge clean up needed Also, forces.jl has yet to be fixed currently --- src/builder.jl | 49 ++++++++--------- src/utils.jl | 102 +++++++++++++++++++++++++++++++---- test/test_equiv_with_cate.jl | 11 ++-- test/test_equivariance.jl | 31 ++++++----- 4 files changed, 141 insertions(+), 52 deletions(-) diff --git a/src/builder.jl b/src/builder.jl index aa54760..360edbe 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -120,19 +120,21 @@ end # TODO: symmetry group O(d)? """ xx2AA(spec_nlm, d=3, categories=[]; radial_basis=legendre_basis) -Construct a lux chain that maps a configuration to the corresponding the AA basis +Construct a lux chain that maps a configuration to the corresponding AA basis spec_nlm: Specification of the AA bases +radial : specified radial basis, with both basis and its specification +=== +OptionalField: d: Input dimension categories : A list of categories -radial_basis : specified radial basis, default using P4ML.legendre_basis """ -function xx2AA(spec_nlm; categories=[], d=3, radial_basis = legendre_basis) # Configuration to AA bases - this is what all chains have in common + +function xx2AA(spec_nlm, radial::Radial_basis; categories=[], d=3) # Configuration to AA bases - this is what all chains have in common # from spec_nlm to all possible spec1p spec1p, lmax, nmax = specnlm2spec1p(spec_nlm) dict_spec1p = Dict([spec1p[i] => i for i = 1:length(spec1p)]) Ylm = CYlmBasis(lmax) - Rn = radial_basis(nmax) - # TODO: make it Rnl = radial_basis(nmax,lmax) + # Rn = radial_basis(nmax) if !isempty(categories) # Read categories from x - TODO: discuss which format we like it to be... @@ -145,8 +147,7 @@ function xx2AA(spec_nlm; categories=[], d=3, radial_basis = legendre_basis) # Co l_δs = P4ML.lux(δs) end - spec1pidx = isempty(categories) ? getspec1idx(spec1p, Rn, Ylm) : getspec1idx(spec1p, Rn, Ylm, δs) - # TODO: write getspec1idx for Rnl basis + spec1pidx = isempty(categories) ? getspec1idx_new(spec1p, radial.Radialspec, Ylm) : getspec1idx_new(spec1p, radial.Radialspec, Ylm, δs) bA = P4ML.PooledSparseProduct(spec1pidx) Spec = sort.([ [dict_spec1p[spec_nlm[k][j]] for j = 1:length(spec_nlm[k])] for k = 1:length(spec_nlm) ]) @@ -154,7 +155,7 @@ function xx2AA(spec_nlm; categories=[], d=3, radial_basis = legendre_basis) # Co bAA = P4ML.SparseSymmProd(Spec) # wrapping into lux layers - l_Rn = P4ML.lux(Rn) + l_Rnl = radial.Rnl l_Ylm = P4ML.lux(Ylm) l_bA = P4ML.lux(bA) l_bAA = P4ML.lux(bAA) @@ -170,11 +171,11 @@ function xx2AA(spec_nlm; categories=[], d=3, radial_basis = legendre_basis) # Co if isempty(categories) l_xnx = Lux.Parallel(nothing; normx = WrappedFunction(_norm), x = WrappedFunction(identity)) - l_embed = Lux.Parallel(nothing; Rn = l_Rn, Ylm = l_Ylm) + l_embed = Lux.Parallel(nothing; Rn = l_Rnl, Ylm = l_Ylm) luxchain = Chain(l_xnx = l_xnx, embed = l_embed, A = l_bA , AA = l_bAA) else l_xnxz = Lux.BranchLayer(normx = WrappedFunction(x -> _norm(x[1])), x = WrappedFunction(x -> x[1]), catlist = WrappedFunction(x -> x[2])) - l_embed = Lux.Parallel(nothing; Rn = l_Rn, Ylm = l_Ylm, δs = l_δs) + l_embed = Lux.Parallel(nothing; Rn = l_Rnl, Ylm = l_Ylm, δs = l_δs) luxchain = Chain(l_xnxz = l_xnxz, embed = l_embed, A = l_bA , AA = l_bAA) end @@ -192,7 +193,7 @@ L : Largest equivariance level categories : A list of categories radial_basis : specified radial basis, default using P4ML.legendre_basis """ -function equivariant_model(spec_nlm, L::Int64; categories=[], d=3, radial_basis=legendre_basis, group="O3", islong=true) +function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true) # first filt out those unfeasible spec_nlm filter_init = islong ? RPE_filter_long(L) : RPE_filter(L) spec_nlm = spec_nlm[findall(x -> filter_init(x) == 1, spec_nlm)] @@ -200,7 +201,7 @@ function equivariant_model(spec_nlm, L::Int64; categories=[], d=3, radial_basis= # sort!(spec_nlm, by = x -> length(x)) spec_nlm = closure(spec_nlm,filter_init; categories = categories) - luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm; categories = categories, d = d, radial_basis = radial_basis) + luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm, radial; categories = categories, d = d) F(X) = luxchain_tmp(X, ps_tmp, st_tmp)[1] if islong @@ -233,13 +234,13 @@ function equivariant_model(spec_nlm, L::Int64; categories=[], d=3, radial_basis= end # more constructors equivariant_model -equivariant_model(totdeg::Int64, ν::Int64, L::Int64; categories=[], d=3, radial_basis=legendre_basis, group="O3", islong=true) = - equivariant_model(degord2spec(;totaldegree = totdeg, order = ν, Lmax=L, radial_basis = radial_basis, islong = islong)[2], L; categories, d, radial_basis, group, islong) +equivariant_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3, radial_basis=legendre_basis, group="O3", islong=true) = + equivariant_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax=L, islong = islong)[2], radial, L; categories, d, group, islong) # With the _close function, the input could simply be an nnlllist (nlist,llist) -equivariant_model(nn::Vector{Int64}, ll::Vector{Int64}, L::Int64; categories=[], d=3, radial_basis = legendre_basis, group = "O3", islong = true) = begin +equivariant_model(nn::Vector{Int64}, ll::Vector{Int64}, radial::Radial_basis, L::Int64; categories=[], d=3, group = "O3", islong = true) = begin filter = islong ? RPE_filter_long(L) : RPE_filter(L) - equivariant_model(_close(nn, ll; filter = filter), L; categories, d, radial_basis, group, islong) + equivariant_model(_close(nn, ll; filter = filter), radial, L; categories, d, group, islong) end # ===== Codes that we might remove later ===== @@ -251,14 +252,14 @@ end # What can be adjusted in its input are: (1) total polynomial degree; (2) correlation order; (3) largest L # (4) weight of the order of spherical harmonics; (5) specified radial basis -function equivariant_SYY_model(spec_nlm, L::Int64; categories=[], d=3, radial_basis=legendre_basis, group="O3") +function equivariant_SYY_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3") filter_init = RPE_filter_long(L) spec_nlm = spec_nlm[findall(x -> filter_init(x) == 1, spec_nlm)] # sort!(spec_nlm, by = x -> length(x)) spec_nlm = closure(spec_nlm, filter_init; categories = categories) - luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm; categories = categories, d = d, radial_basis = radial_basis) + luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm, radial; categories = categories, d = d) F(X) = luxchain_tmp(X, ps_tmp, st_tmp)[1] cgen = Rot3DCoeffs_long(L) # TODO: this should be made group related @@ -274,11 +275,11 @@ function equivariant_SYY_model(spec_nlm, L::Int64; categories=[], d=3, radial_ba return luxchain, ps, st end -equivariant_SYY_model(totdeg::Int64, ν::Int64, L::Int64; categories=[], d=3, radial_basis = legendre_basis,group = "O3") = - equivariant_SYY_model(degord2spec(;totaldegree = totdeg, order = ν, Lmax = L, radial_basis = radial_basis, islong=true)[2], L; categories, d, radial_basis, group) +equivariant_SYY_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3,group = "O3") = + equivariant_SYY_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax = L, islong=true)[2], radial, L; categories, d, group) -equivariant_SYY_model(nn::Vector{Int64}, ll::Vector{Int64}, L::Int64; categories=[], d=3, radial_basis=legendre_basis, group="O3") = - equivariant_SYY_model(_close(nn, ll; filter = RPE_filter_long(L)), L; categories, d, radial_basis, group) +equivariant_SYY_model(nn::Vector{Int64}, ll::Vector{Int64}, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3") = + equivariant_SYY_model(_close(nn, ll; filter = RPE_filter_long(L)), radial, L; categories, d, group) ## TODO: The following should eventually go into ACEhamiltonians.jl rather than this package @@ -312,7 +313,7 @@ function equivariant_luxchain_constructor(totdeg, ν, L; wL = 1, Rn = legendre_b Ylm = CYlmBasis(totdeg) - spec1p = make_nlms_spec(Rn, Ylm; totaldegree = totdeg, admissible = (br, by) -> br + wL * by.l <= totdeg) + spec1p = make_nlms_spec(Radial_basis(Polynomials4ML.lux(Rn)), Ylm; totaldegree = totdeg, admissible = (br, by) -> br.n + wL * by.l <= totdeg) spec1p = sort(spec1p, by = (x -> x.n + x.l * wL)) spec1pidx = getspec1idx(spec1p, Rn, Ylm) @@ -370,7 +371,7 @@ end function equivariant_luxchain_constructor_new(totdeg, ν, L; wL = 1, Rn = legendre_basis(totdeg)) Ylm = CYlmBasis(totdeg) - spec1p = make_nlms_spec(Rn, Ylm; totaldegree = totdeg, admissible = (br, by) -> br + wL * by.l <= totdeg) + spec1p = make_nlms_spec(Radial_basis(Polynomials4ML.lux(Rn)), Ylm; totaldegree = totdeg, admissible = (br, by) -> br.n + wL * by.l <= totdeg) spec1p = sort(spec1p, by = (x -> x.n + x.l * wL)) spec1pidx = getspec1idx(spec1p, Rn, Ylm) diff --git a/src/utils.jl b/src/utils.jl index f9c4a53..4f6dc22 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,5 +1,20 @@ using Polynomials4ML: natural_indices +using LuxCore: AbstractExplicitLayer +struct Radial_basis + Rnl::AbstractExplicitLayer + Radialspec::Vector{NamedTuple} +end + +Radial_basis(Rnl::AbstractExplicitLayer, spec_Rnl::Union{Vector{Int}, UnitRange{Int64}}) = + Radial_basis(Rnl, [(n = i, ) for i in spec_Rnl]) + +Radial_basis(Rnl::AbstractExplicitLayer) = + try + Radial_basis(Rnl,natural_indices(Rnl.basis)) + catch + error("The specification of this Radial_basis should be given explicitly!") + end """ _invmap(a::AbstractVector) @@ -30,15 +45,30 @@ function getspec1idx(spec1, bRnl, bYlm) spec1idx = Vector{Tuple{Int, Int}}(undef, length(spec1)) spec_Rnl = natural_indices(bRnl); # TODO: the following line is to be changed to be l-dependent - spec_Rnl = [(n = i, ) for i in spec_Rnl] + if typeof(spec_Rnl[1]) <: Int + spec_Rnl = [(n = i, ) for i in spec_Rnl] + is_l = false + elseif !(typeof(spec_Rnl[1]) <: NamedTuple) + is_l = true + @error("Unexpected type of Rnl - Probably means that it is not defined yet") + end + inv_Rnl = _invmap(spec_Rnl) spec_Ylm = natural_indices(bYlm); inv_Ylm = _invmap(spec_Ylm) spec1idx = Vector{Tuple{Int, Int}}(undef, length(spec1)) - for (i, b) in enumerate(spec1) - spec1idx[i] = (inv_Rnl[dropnames(b, (:m, :l))], inv_Ylm[(l=b.l, m=b.m)]) + + if is_l + for (i, b) in enumerate(spec1) + spec1idx[i] = (inv_Rnl[dropnames(b, (:m, ))], inv_Ylm[(l=b.l, m=b.m)]) + end + else + for (i, b) in enumerate(spec1) + spec1idx[i] = (inv_Rnl[dropnames(b, (:m, :l))], inv_Ylm[(l=b.l, m=b.m)]) + end end + return spec1idx end @@ -60,22 +90,74 @@ function getspec1idx(spec1, bRnl, bYlm, bδs) return spec1idx end +function getspec1idx_new(spec1, spec_Rnl, bYlm) + spec1idx = Vector{Tuple{Int, Int}}(undef, length(spec1)) + # try is_l = isinteger(spec_Rnl[1].l); catch; is_l = false; end + inv_Rnl = _invmap(spec_Rnl) + + spec_Ylm = natural_indices(bYlm); inv_Ylm = _invmap(spec_Ylm) + + spec1idx = Vector{Tuple{Int, Int}}(undef, length(spec1)) + + if length(spec_Rnl[1]) > 1 && haskey(spec_Rnl[1],:l) + for (i, b) in enumerate(spec1) + spec1idx[i] = (inv_Rnl[dropnames(b, (:m, ))], inv_Ylm[(l=b.l, m=b.m)]) + end + else + for (i, b) in enumerate(spec1) + spec1idx[i] = (inv_Rnl[dropnames(b, (:m, :l))], inv_Ylm[(l=b.l, m=b.m)]) + end + end + + return spec1idx +end + +function getspec1idx_new(spec1, spec_Rnl, bYlm, bδs) + spec1idx = Vector{Tuple{Int, Int, Int}}(undef, length(spec1)) + # try is_l = isinteger(spec_Rnl[1].l); catch; is_l = false; end + inv_Rnl = _invmap(spec_Rnl) + + spec_Ylm = natural_indices(bYlm); inv_Ylm = _invmap(spec_Ylm) + + slist = bδs.categories + + spec1idx = Vector{Tuple{Int, Int, Int}}(undef, length(spec1)) + + if length(spec_Rnl[1]) > 1 && haskey(spec_Rnl[1],:l) + for (i, b) in enumerate(spec1) + spec1idx[i] = (inv_Rnl[dropnames(b, (:m, :s))], inv_Ylm[(l=b.l, m=b.m)], val2i(slist, b.s)) + end + else + for (i, b) in enumerate(spec1) + spec1idx[i] = (inv_Rnl[dropnames(b, (:m, :l, :s))], inv_Ylm[(l=b.l, m=b.m)], val2i(slist, b.s)) + end + end + + return spec1idx +end + """ make_nlms_spec(bRnl, bYlm) Return a vector of tuples of indices of spec1 w.r.t naural indices (i.e. (n = ..., l = ..., m = ...) ) of bRnl and bYlm """ -function make_nlms_spec(bRn, bYlm; +function make_nlms_spec(radial::Radial_basis, bYlm; totaldegree::Int64 = -1, admissible = nothing, nnuc = 0) - spec_Rn = natural_indices(bRn) + spec_Rn = radial.Radialspec spec_Ylm = natural_indices(bYlm) spec1 = [] for (iR, br) in enumerate(spec_Rn), (iY, by) in enumerate(spec_Ylm) if admissible(br, by) - push!(spec1, (n = br, l = by.l, m = by.m)) + if haskey(br,:l) + if br.l == by.l + push!(spec1, (n = br.n, l = by.l, m = by.m)) + end + else + push!(spec1, (n = br.n, l = by.l, m = by.m)) + end end end return spec1 @@ -197,13 +279,13 @@ end degord2spec(;totaldegree, order, Lmax, radial_basis = legendre_basis, wL = 1, islong = true) Return a list of AA specifications and A specifications """ -function degord2spec(;totaldegree, order, Lmax, catagories = [], radial_basis = legendre_basis, wL = 1, islong = true) - Rn = radial_basis(totaldegree) +function degord2spec(radial::Radial_basis; totaldegree, order, Lmax, catagories = [], wL = 1, islong = true) + # Rn = radial.radial_basis(totaldegree) Ylm = CYlmBasis(totaldegree) - spec1p = make_nlms_spec(Rn, Ylm; totaldegree = totaldegree, admissible = (br, by) -> br + wL * by.l <= totaldegree) + spec1p = make_nlms_spec(radial, Ylm; totaldegree = totaldegree, admissible = (br, by) -> br.n + wL * by.l <= totaldegree) spec1p = sort(spec1p, by = (x -> x.n + x.l * wL)) - spec1pidx = getspec1idx(spec1p, Rn, Ylm) + spec1pidx = getspec1idx_new(spec1p, radial.Radialspec, Ylm) # define sparse for n-correlations tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ] diff --git a/test/test_equiv_with_cate.jl b/test/test_equiv_with_cate.jl index fb49796..68e876d 100644 --- a/test/test_equiv_with_cate.jl +++ b/test/test_equiv_with_cate.jl @@ -1,13 +1,16 @@ using Polynomials4ML, StaticArrays, EquivariantModels, Test, Rotations, LinearAlgebra using ACEbase.Testing: print_tf using EquivariantModels: getspec1idx, _invmap, dropnames, SList, val2i, xx2AA, degord2spec +using Polynomials4ML: lux include("wigner.jl") L = 4 - -Aspec, AAspec = degord2spec(; totaldegree = 4, - order = 2, +totdeg = 4 +ord = 2 +radial = Radial_basis(legendre_basis(totdeg) |> lux) +Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, + order = ord, Lmax = 0, ) cats = [:O,:C] cats_ext = [(:O,:C),(:C,:O),(:O,:O),(:C,:C)] |> unique @@ -22,7 +25,7 @@ _AAspec_tmp2 = [ [(AAspec[i][1]..., s = cats_ext[2]), (AAspec[i][2]..., s = cats append!(AAspec_tmp,_AAspec_tmp) append!(AAspec_tmp,_AAspec_tmp2) -luxchain, ps, st = equivariant_model(AAspec_tmp, L; categories=cats_ext) +luxchain, ps, st = equivariant_model(AAspec_tmp, radial, L; categories=cats_ext) F(X) = luxchain(X, ps, st)[1] species = [ rand(cats) for i = 1:10 ] Species = [ (species[1], species[i]) for i = 1:10 ] diff --git a/test/test_equivariance.jl b/test/test_equivariance.jl index 50a1ab6..4fc843a 100644 --- a/test/test_equivariance.jl +++ b/test/test_equivariance.jl @@ -1,9 +1,8 @@ -using EquivariantModels -using StaticArrays -using Test +using EquivariantModels, StaticArrays, Test, Polynomials4ML, LinearAlgebra using ACEbase.Testing: print_tf using Rotations, WignerD, BlockDiagonals -using LinearAlgebra +using EquivariantModels: Radial_basis +using Polynomials4ML:lux include("wigner.jl") @@ -11,13 +10,14 @@ include("wigner.jl") totdeg = 6 ν = 2 Lmax = 2 +radial = Radial_basis(legendre_basis(totdeg) |> lux) for L = 0:Lmax local F, luxchain, ps, st, F2, luxchain2, ps2, st2 - luxchain, ps, st = equivariant_model(totdeg, ν, L;islong = false) + luxchain, ps, st = equivariant_model(totdeg, ν, radial, L;islong = false) F(X) = luxchain(X, ps, st)[1] - luxchain2, ps2, st2 = equivariant_model(EquivariantModels.degord2spec(;totaldegree=totdeg,order=ν,Lmax=L,islong = true)[2][1:end-1],L;islong = false) + luxchain2, ps2, st2 = equivariant_model(EquivariantModels.degord2spec(radial;totaldegree=totdeg,order=ν,Lmax=L,islong = true)[2][1:end-1],radial,L;islong = false) F2(X) = luxchain(X, ps2, st2)[1] @info("Tesing L = $L O(3) equivariance") @@ -54,9 +54,10 @@ end totdeg = 6 ν = 2 L = Lmax -luxchain, ps, st = equivariant_model(totdeg,ν,L;islong = true) +radial = Radial_basis(legendre_basis(totdeg) |> lux) +luxchain, ps, st = equivariant_model(totdeg,ν,radial,L;islong = true) F(X) = luxchain(X, ps, st)[1] -luxchain2, ps2, st2 = equivariant_model(EquivariantModels.degord2spec(;totaldegree=totdeg,order=ν,Lmax=L,islong = true)[2][1:end-1],L;islong = true) +luxchain2, ps2, st2 = equivariant_model(EquivariantModels.degord2spec(radial;totaldegree=totdeg,order=ν,Lmax=L,islong = true)[2][1:end-1],radial,L;islong = true) F2(X) = luxchain(X, ps2, st2)[1] for ntest = 1:10 @@ -82,13 +83,14 @@ println() totdeg = 6 ν = 2 L = Lmax -luxchain, ps, st = equivariant_model(totdeg,ν,L;islong = true); +radial = Radial_basis(legendre_basis(totdeg) |> lux) +luxchain, ps, st = equivariant_model(totdeg,ν,radial,L;islong = true); F(X) = luxchain(X, ps, st)[1] for l = 0:Lmax @info("Consistency check for L = $l") local FF, luxchain, ps, st - luxchain, ps, st = equivariant_model(totdeg,ν,l;islong = false) + luxchain, ps, st = equivariant_model(totdeg,ν,radial,l;islong = false) FF(X) = luxchain(X, ps, st)[1] for ntest = 1:20 @@ -116,7 +118,7 @@ for L = 0:Lmax while iseven(L) != iseven(sum(ll)) ll = rand(0:2,4) end - luxchain, ps, st = equivariant_model(nn,ll,L;islong = false) + luxchain, ps, st = equivariant_model(nn,ll,radial,L;islong = false) F(X) = luxchain(X, ps, st)[1] @info("Tesing L = $L O(3) equivariance") @@ -145,9 +147,10 @@ end totdeg = 6 ν = 2 L = Lmax -luxchain, ps, st = equivariant_SYY_model(totdeg,ν,L); +radial = Radial_basis(legendre_basis(totdeg) |> lux) +luxchain, ps, st = equivariant_SYY_model(totdeg,ν,radial,L); F(X) = luxchain(X, ps, st)[1] -luxchain2, ps2, st2 = equivariant_SYY_model(EquivariantModels.degord2spec(;totaldegree=totdeg,order=ν,Lmax=L,islong = true)[2][1:end-1],L) +luxchain2, ps2, st2 = equivariant_SYY_model(EquivariantModels.degord2spec(radial;totaldegree=totdeg,order=ν,Lmax=L,islong = true)[2][1:end-1],radial,L) F2(X) = luxchain(X, ps2, st2)[1] @info("Tesing L = $L O(3) full equivariance") @@ -182,7 +185,7 @@ while iseven(Lmax) != iseven(sum(ll)) global ll = rand(0:2,4) end -luxchain, ps, st = equivariant_SYY_model(nn, ll, L) +luxchain, ps, st = equivariant_SYY_model(nn, ll, radial, L) F(X) = luxchain(X, ps, st)[1] @info("Tesing L = $L O(3) full equivariance") From 0e62cdb67baf6fe77c4f518ad6e3ef5680f108d9 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Fri, 29 Sep 2023 15:13:06 -0700 Subject: [PATCH 03/18] improving the readability of embedding layer --- examples/potential/forces.jl | 43 ++++++++++++++++-------------------- src/builder.jl | 12 +++++----- src/utils.jl | 27 +++++++++++++++++++++- test/test_equiv_with_cate.jl | 9 ++++---- test/test_equivariance.jl | 14 +++++++----- 5 files changed, 66 insertions(+), 39 deletions(-) diff --git a/examples/potential/forces.jl b/examples/potential/forces.jl index 028b710..783242c 100644 --- a/examples/potential/forces.jl +++ b/examples/potential/forces.jl @@ -1,38 +1,33 @@ -using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote +using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote, Polynomials4ML using Polynomials4ML: LinearLayer, RYlmBasis, lux -using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA +using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA, simple_radial_basis rng = Random.MersenneTwister() ## rcut = 5.5 maxL = 0 -Aspec, AAspec = degord2spec(; totaldegree = 6, - order = 3, +totdeg = 6 +ord = 3 +radial = simple_radial_basis(legendre_basis(totdeg)) +Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, + order = ord, Lmax = maxL, ) -l_basis, ps_basis, st_basis = equivariant_model(AAspec, maxL) +l_basis, ps_basis, st_basis = equivariant_model(AAspec, radial, maxL; islong = false) X = [ @SVector(randn(3)) for i in 1:10 ] -B = l_basis(X, ps_basis, st_basis)[1][1] +B = l_basis(X, ps_basis, st_basis)[1] # now build another model with a better transform -L = maximum(b.l for b in Aspec) len_BB = length(B) -get1 = WrappedFunction(t -> t[1]) -embed = Parallel(nothing; - Rn = Chain(trans = WrappedFunction(xx -> [1/(1+norm(x)) for x in xx]), - poly = l_basis.layers.embed.layers.Rn, ), - Ylm = Chain(Ylm = lux(RYlmBasis(L)), ) ) - -model = Chain( - embed = embed, - A = l_basis.layers.A, - AA = l_basis.layers.AA, - # AA_sort = l_basis.layers.AA_sort, - BB = l_basis.layers.BB, - get1 = WrappedFunction(t -> t[1]), - dot = LinearLayer(len_BB, 1), - get2 = WrappedFunction(t -> t[1]), ) +# embed = Parallel(nothing; +# Rn = Chain(trans = WrappedFunction(xx -> [1/(1+norm(x)) for x in xx]), +# poly = l_basis.layers.embed.layers.Rn, ), +# Ylm = Chain(Ylm = lux(RYlmBasis(L)), ) ) + +model = append_layer(l_basis, LinearLayer(len_BB, 1); l_name=:dot) +model = append_layer(model, WrappedFunction(t -> real(t[1])); l_name=:get1) + ps, st = Lux.setup(rng, model) out, st = model(X, ps, st) @@ -158,7 +153,7 @@ end using JuLIP JuLIP.usethreads!(false) -ps.dot.W[:] .= 0.01 * randn(length(ps.dot.W)) +ps.dot.W[:] .= 1e-12 * randn(length(ps.dot.W)) at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) calc = Pot.LuxCalc(model, ps, st, rcut) @@ -217,4 +212,4 @@ end loss(at, calc, p_vec) -ReverseDiff.gradient(p -> loss(at, calc, p), p_vec) +# ReverseDiff.gradient(p -> loss(at, calc, p), p_vec) diff --git a/src/builder.jl b/src/builder.jl index 360edbe..19414fa 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -170,13 +170,15 @@ function xx2AA(spec_nlm, radial::Radial_basis; categories=[], d=3) # Configurati _norm(x) = norm.(x) if isempty(categories) - l_xnx = Lux.Parallel(nothing; normx = WrappedFunction(_norm), x = WrappedFunction(identity)) l_embed = Lux.Parallel(nothing; Rn = l_Rnl, Ylm = l_Ylm) - luxchain = Chain(l_xnx = l_xnx, embed = l_embed, A = l_bA , AA = l_bAA) - else - l_xnxz = Lux.BranchLayer(normx = WrappedFunction(x -> _norm(x[1])), x = WrappedFunction(x -> x[1]), catlist = WrappedFunction(x -> x[2])) + luxchain = Chain(embed = l_embed, A = l_bA , AA = l_bAA) + else + l_Rnl = append_layer(Chain(get_pos = get_i(1), ), l_Rnl; l_name = :radial_poly) + l_Ylm = append_layer(Chain(get_pos = get_i(1), ), l_Ylm; l_name = :angle_poly) + l_δs = append_layer(Chain(get_cat = get_i(2), ), l_δs; l_name = :categorical) + l_embed = Lux.Parallel(nothing; Rn = l_Rnl, Ylm = l_Ylm, δs = l_δs) - luxchain = Chain(l_xnxz = l_xnxz, embed = l_embed, A = l_bA , AA = l_bAA) + luxchain = Chain(embed = l_embed, A = l_bA , AA = l_bAA) # Chain(l_xnxz = l_xnxz, embed = l_embed, A = l_bA , AA = l_bAA) end # luxchain = Chain(l_xnxz = l_xnxz, embed = l_embed, A = l_bA , AA = l_bAA) diff --git a/src/utils.jl b/src/utils.jl index 4f6dc22..e73cb66 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,4 +1,4 @@ -using Polynomials4ML: natural_indices +using Polynomials4ML: natural_indices, ScalarPoly4MLBasis, lux using LuxCore: AbstractExplicitLayer struct Radial_basis @@ -16,6 +16,29 @@ Radial_basis(Rnl::AbstractExplicitLayer) = error("The specification of this Radial_basis should be given explicitly!") end +# more parameters should be added to this function - it is in its current form just for testing +function simple_radial_basis(basis::ScalarPoly4MLBasis; spec = nothing) + if isnothing(spec) + try + spec = natural_indices(basis) + catch + error("The specification of this Radial_basis should be given explicitly!") + end + end + + function f_cut(r) + return r + end + + function f_tran(r) + return r + end + f(r) = f_cut(r) * f_tran(r) + + return Radial_basis(Chain(trans = WrappedFunction(xx -> [f(norm(x)) for x in xx]), + poly = lux(basis), ), spec) +end + """ _invmap(a::AbstractVector) Return a dictionary that maps the elements of a to their indices @@ -309,3 +332,5 @@ function degord2spec(radial::Radial_basis; totaldegree, order, Lmax, catagories Aspec = specnlm2spec1p(AAspec)[1] return Aspec, AAspec # Aspecgetspecnlm(spec1p, spec) end + +get_i(i) = WrappedFunction(t -> t[i]) diff --git a/test/test_equiv_with_cate.jl b/test/test_equiv_with_cate.jl index 68e876d..02814bb 100644 --- a/test/test_equiv_with_cate.jl +++ b/test/test_equiv_with_cate.jl @@ -1,6 +1,6 @@ using Polynomials4ML, StaticArrays, EquivariantModels, Test, Rotations, LinearAlgebra using ACEbase.Testing: print_tf -using EquivariantModels: getspec1idx, _invmap, dropnames, SList, val2i, xx2AA, degord2spec +using EquivariantModels: getspec1idx, _invmap, dropnames, SList, val2i, xx2AA, degord2spec, simple_radial_basis using Polynomials4ML: lux include("wigner.jl") @@ -8,7 +8,8 @@ include("wigner.jl") L = 4 totdeg = 4 ord = 2 -radial = Radial_basis(legendre_basis(totdeg) |> lux) +radial = simple_radial_basis(legendre_basis(totdeg)) +# radial = Radial_basis(legendre_basis(totdeg) |> lux) Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, order = ord, Lmax = 0, ) @@ -34,13 +35,13 @@ Species = [ (species[1], species[i]) for i = 1:10 ] for ntest = 1:10 local X, θ1, θ2, θ3, Q, QX X = [ @SVector(rand(3)) for i in 1:10 ] - XX = (X, Species) + XX = [X, Species] θ1 = rand() * 2pi θ2 = rand() * 2pi θ3 = rand() * 2pi Q = RotXYZ(θ1, θ2, θ3) QX = [SVector{3}(x) for x in Ref(Q) .* X] - QXX = (QX, Species) + QXX = [QX, Species] print_tf(@test F(XX)[1] ≈ F(QXX)[1]) diff --git a/test/test_equivariance.jl b/test/test_equivariance.jl index 4fc843a..a07c1a6 100644 --- a/test/test_equivariance.jl +++ b/test/test_equivariance.jl @@ -10,7 +10,8 @@ include("wigner.jl") totdeg = 6 ν = 2 Lmax = 2 -radial = Radial_basis(legendre_basis(totdeg) |> lux) +basis = legendre_basis(totdeg) +radial = EquivariantModels.simple_radial_basis(basis) for L = 0:Lmax local F, luxchain, ps, st, F2, luxchain2, ps2, st2 @@ -54,7 +55,8 @@ end totdeg = 6 ν = 2 L = Lmax -radial = Radial_basis(legendre_basis(totdeg) |> lux) +basis = legendre_basis(totdeg) +radial = EquivariantModels.simple_radial_basis(basis) luxchain, ps, st = equivariant_model(totdeg,ν,radial,L;islong = true) F(X) = luxchain(X, ps, st)[1] luxchain2, ps2, st2 = equivariant_model(EquivariantModels.degord2spec(radial;totaldegree=totdeg,order=ν,Lmax=L,islong = true)[2][1:end-1],radial,L;islong = true) @@ -74,7 +76,7 @@ for ntest = 1:10 for l = 2:L D = wigner_D(l-1,Matrix(Q))' # D = wignerD(l-1, 0, 0, θ) - print_tf(@test norm.(Ref(D') .* F(X)[l] - F(QX)[l]) |> norm <1e-12) + print_tf(@test norm.(Ref(D') .* F(X)[l] - F(QX)[l]) |> norm <1e-11) end end println() @@ -83,7 +85,8 @@ println() totdeg = 6 ν = 2 L = Lmax -radial = Radial_basis(legendre_basis(totdeg) |> lux) +basis = legendre_basis(totdeg) +radial = EquivariantModels.simple_radial_basis(basis) luxchain, ps, st = equivariant_model(totdeg,ν,radial,L;islong = true); F(X) = luxchain(X, ps, st)[1] @@ -147,7 +150,8 @@ end totdeg = 6 ν = 2 L = Lmax -radial = Radial_basis(legendre_basis(totdeg) |> lux) +basis = legendre_basis(totdeg) +radial = EquivariantModels.simple_radial_basis(basis) luxchain, ps, st = equivariant_SYY_model(totdeg,ν,radial,L); F(X) = luxchain(X, ps, st)[1] luxchain2, ps2, st2 = equivariant_SYY_model(EquivariantModels.degord2spec(radial;totaldegree=totdeg,order=ν,Lmax=L,islong = true)[2][1:end-1],radial,L) From b2b56bdee4ffcb4360e9267c04aa92b3abc446e4 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Fri, 29 Sep 2023 18:01:51 -0700 Subject: [PATCH 04/18] update forces.jl correspondingly --- examples/potential/forces.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/potential/forces.jl b/examples/potential/forces.jl index 783242c..d5987f2 100644 --- a/examples/potential/forces.jl +++ b/examples/potential/forces.jl @@ -20,13 +20,10 @@ B = l_basis(X, ps_basis, st_basis)[1] # now build another model with a better transform len_BB = length(B) -# embed = Parallel(nothing; -# Rn = Chain(trans = WrappedFunction(xx -> [1/(1+norm(x)) for x in xx]), -# poly = l_basis.layers.embed.layers.Rn, ), -# Ylm = Chain(Ylm = lux(RYlmBasis(L)), ) ) -model = append_layer(l_basis, LinearLayer(len_BB, 1); l_name=:dot) -model = append_layer(model, WrappedFunction(t -> real(t[1])); l_name=:get1) +model = append_layer(l_basis, WrappedFunction(t -> real(t)); l_name=:real) +model = append_layer(model, LinearLayer(len_BB, 1); l_name=:dot) +model = append_layer(model, WrappedFunction(t -> t[1]); l_name=:get1) ps, st = Lux.setup(rng, model) out, st = model(X, ps, st) @@ -213,3 +210,4 @@ loss(at, calc, p_vec) # ReverseDiff.gradient(p -> loss(at, calc, p), p_vec) + From e59f67cae61d21334066d6250652064295693ce4 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Fri, 29 Sep 2023 20:42:48 -0700 Subject: [PATCH 05/18] Clean up + enable the rSH based chain (for L = 0 only) --- examples/potential/forces.jl | 11 ++- src/builder.jl | 60 ++++++++------- src/utils.jl | 78 ++++--------------- test/runtests.jl | 6 +- test/test_rSH_equivariance.jl | 139 ++++++++++++++++++++++++++++++++++ 5 files changed, 195 insertions(+), 99 deletions(-) create mode 100644 test/test_rSH_equivariance.jl diff --git a/examples/potential/forces.jl b/examples/potential/forces.jl index d5987f2..450b326 100644 --- a/examples/potential/forces.jl +++ b/examples/potential/forces.jl @@ -9,7 +9,11 @@ rcut = 5.5 maxL = 0 totdeg = 6 ord = 3 -radial = simple_radial_basis(legendre_basis(totdeg)) + +fcut(rcut::Float64,pin::Int=2,pout::Int=2) = r -> (r < rcut ? abs( (r/rcut)^pin - 1)^pout : 0) +ftrans(r0::Float64=.0,p::Int=2) = r -> ( (1+r0)/(1+r) )^p +radial = simple_radial_basis(legendre_basis(totdeg),fcut(rcut),ftrans()) + Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, order = ord, Lmax = maxL, ) @@ -18,7 +22,7 @@ l_basis, ps_basis, st_basis = equivariant_model(AAspec, radial, maxL; islong = f X = [ @SVector(randn(3)) for i in 1:10 ] B = l_basis(X, ps_basis, st_basis)[1] -# now build another model with a better transform +# now extend the above BB basis to a model len_BB = length(B) model = append_layer(l_basis, WrappedFunction(t -> real(t)); l_name=:real) @@ -150,7 +154,7 @@ end using JuLIP JuLIP.usethreads!(false) -ps.dot.W[:] .= 1e-12 * randn(length(ps.dot.W)) +ps.dot.W[:] .= 1e-2 * randn(length(ps.dot.W)) at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) calc = Pot.LuxCalc(model, ps, st, rcut) @@ -210,4 +214,3 @@ loss(at, calc, p_vec) # ReverseDiff.gradient(p -> loss(at, calc, p), p_vec) - diff --git a/src/builder.jl b/src/builder.jl index 19414fa..0535923 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -1,6 +1,6 @@ using LinearAlgebra using SparseArrays: SparseMatrixCSC, sparse -using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real, Rot3DCoeffs_long, re_basis, SYYVector +using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real, Rot3DCoeffs_long, re_basis, SYYVector, mm_filter using Polynomials4ML: legendre_basis, RYlmBasis, natural_indices, degree using Polynomials4ML.Utils: gensparse using Lux: WrappedFunction @@ -16,6 +16,8 @@ P4ML = Polynomials4ML RPE_filter(L) = bb -> (length(bb) == 0) || ((abs(sum(b.m for b in bb)) <= L) && iseven(sum(b.l for b in bb)+L)) RPE_filter_long(L) = bb -> (length(bb) == 0) || (abs(sum(b.m for b in bb)) <= L) +RPE_filter_real(L) = bb -> (length(bb) == 0) || mm_filter([b.m for b in bb],L) && iseven(sum(b.l for b in bb)+L) + """ _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T},Rot3DCoeffs_real{L,T},Rot3DCoeffs_long{L,T}}, spec::Vector{Vector{NamedTuple}}) @@ -36,16 +38,12 @@ function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Ro for i = 1:length(spec) # get the specification of the ith basis function, which is a tuple/vec of NamedTuples pib = spec[i] - # skip it unless all m are zero, because we want to consider each - # (nn, ll) block only once. - # if !all(b.m == 0 for b in pib) - # continue - # end - # But we can not do this anymore for L≥1, so I add an nnllset # get the rotation-coefficients for this basis group # the bs are the basis functions corresponding to the columns + # The nnlllist is created because we want to consider each + # (nn, ll) block only once. nn = SVector([onep.n for onep in pib]...) ll = SVector([onep.l for onep in pib]...) # get a SVector of ll index if haskey(pib[1],:s) @@ -102,7 +100,12 @@ function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Ro if !isnothing(idxAA) push!(Irow, idxB) push!(Jcol, idxAA) - push!(vals, U[irow, icol]) + if norm(U[irow, icol])<1e-12 + push!(vals, real.(U[irow, icol])) + else + push!(vals, U[irow, icol]) + end + # push!(vals, U[irow, icol]) end end end @@ -119,7 +122,7 @@ end # TODO: symmetry group O(d)? """ -xx2AA(spec_nlm, d=3, categories=[]; radial_basis=legendre_basis) +xx2AA(spec_nlm, radial; d=3, categories=[]) Construct a lux chain that maps a configuration to the corresponding AA basis spec_nlm: Specification of the AA bases radial : specified radial basis, with both basis and its specification @@ -129,25 +132,20 @@ d: Input dimension categories : A list of categories """ -function xx2AA(spec_nlm, radial::Radial_basis; categories=[], d=3) # Configuration to AA bases - this is what all chains have in common +function xx2AA(spec_nlm, radial::Radial_basis; categories=[], d=3, rSH = false) # Configuration to AA bases - this is what all chains have in common # from spec_nlm to all possible spec1p spec1p, lmax, nmax = specnlm2spec1p(spec_nlm) dict_spec1p = Dict([spec1p[i] => i for i = 1:length(spec1p)]) - Ylm = CYlmBasis(lmax) + Ylm = rSH ? RYlmBasis(lmax) : CYlmBasis(lmax) # Rn = radial_basis(nmax) if !isempty(categories) - # Read categories from x - TODO: discuss which format we like it to be... - # For now we just give get_cat(x) a random value - #get_cat(x) = length(categories) > 1 ? (iseven(floor(norm(x))) ? categories[1] : categories[2]) : categories[1] - #_get_cat(x) = get_cat.(x) - # Define categorical bases δs = CategoricalBasis(categories) l_δs = P4ML.lux(δs) end - spec1pidx = isempty(categories) ? getspec1idx_new(spec1p, radial.Radialspec, Ylm) : getspec1idx_new(spec1p, radial.Radialspec, Ylm, δs) + spec1pidx = isempty(categories) ? getspec1idx(spec1p, radial.Radialspec, Ylm) : getspec1idx(spec1p, radial.Radialspec, Ylm, δs) bA = P4ML.PooledSparseProduct(spec1pidx) Spec = sort.([ [dict_spec1p[spec_nlm[k][j]] for j = 1:length(spec_nlm[k])] for k = 1:length(spec_nlm) ]) @@ -195,15 +193,19 @@ L : Largest equivariance level categories : A list of categories radial_basis : specified radial basis, default using P4ML.legendre_basis """ -function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true) +function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false) + if rSH && L > 0 + error("rSH is only implemented (for now) for L = 0") + end + # first filt out those unfeasible spec_nlm - filter_init = islong ? RPE_filter_long(L) : RPE_filter(L) + filter_init = rSH ? RPE_filter_real(L) : (islong ? RPE_filter_long(L) : RPE_filter(L)) spec_nlm = spec_nlm[findall(x -> filter_init(x) == 1, spec_nlm)] # sort!(spec_nlm, by = x -> length(x)) spec_nlm = closure(spec_nlm,filter_init; categories = categories) - luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm, radial; categories = categories, d = d) + luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm, radial; categories = categories, d = d, rSH = rSH) F(X) = luxchain_tmp(X, ps_tmp, st_tmp)[1] if islong @@ -212,15 +214,15 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= pos = Vector{Any}(undef, L+1) for l = 0:L - filter = RPE_filter(l) - cgen = Rot3DCoeffs(l) # TODO: this should be made group related + filter = rSH ? RPE_filter_real(L) : RPE_filter(l) + cgen = rSH ? Rot3DCoeffs_real(l) : Rot3DCoeffs(l) # TODO: this should be made group related tmp = spec_nlm[findall(x -> filter(x) == 1, spec_nlm)] C[l+1] = _rpi_A2B_matrix(cgen, tmp) pos[l+1] = findall(x -> filter(x) == 1, spec_nlm) # [ dict[tmp[j]] for j = 1:length(tmp)] end else - cgen = Rot3DCoeffs(L) # TODO: this should be made group related + cgen = rSH ? Rot3DCoeffs_real(L) : Rot3DCoeffs(L) # TODO: this should be made group related C = _rpi_A2B_matrix(cgen, spec_nlm) end @@ -236,13 +238,13 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= end # more constructors equivariant_model -equivariant_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3, radial_basis=legendre_basis, group="O3", islong=true) = - equivariant_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax=L, islong = islong)[2], radial, L; categories, d, group, islong) +equivariant_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false) = + equivariant_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax=L, islong = islong)[2], radial, L; categories, d, group, islong, rSH) # With the _close function, the input could simply be an nnlllist (nlist,llist) -equivariant_model(nn::Vector{Int64}, ll::Vector{Int64}, radial::Radial_basis, L::Int64; categories=[], d=3, group = "O3", islong = true) = begin +equivariant_model(nn::Vector{Int64}, ll::Vector{Int64}, radial::Radial_basis, L::Int64; categories=[], d=3, group = "O3", islong = true, rSH = false) = begin filter = islong ? RPE_filter_long(L) : RPE_filter(L) - equivariant_model(_close(nn, ll; filter = filter), radial, L; categories, d, group, islong) + equivariant_model(_close(nn, ll; filter = filter), radial, L; categories, d, group, islong, rSH) end # ===== Codes that we might remove later ===== @@ -317,7 +319,7 @@ function equivariant_luxchain_constructor(totdeg, ν, L; wL = 1, Rn = legendre_b spec1p = make_nlms_spec(Radial_basis(Polynomials4ML.lux(Rn)), Ylm; totaldegree = totdeg, admissible = (br, by) -> br.n + wL * by.l <= totdeg) spec1p = sort(spec1p, by = (x -> x.n + x.l * wL)) - spec1pidx = getspec1idx(spec1p, Rn, Ylm) + spec1pidx = getspec1idx(spec1p, Radial_basis(Polynomials4ML.lux(Rn)).Radialspec, Ylm) # define sparse for n-correlations tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ] @@ -375,7 +377,7 @@ function equivariant_luxchain_constructor_new(totdeg, ν, L; wL = 1, Rn = legend spec1p = make_nlms_spec(Radial_basis(Polynomials4ML.lux(Rn)), Ylm; totaldegree = totdeg, admissible = (br, by) -> br.n + wL * by.l <= totdeg) spec1p = sort(spec1p, by = (x -> x.n + x.l * wL)) - spec1pidx = getspec1idx(spec1p, Rn, Ylm) + spec1pidx = getspec1idx(spec1p, Radial_basis(Polynomials4ML.lux(Rn)).Radialspec, Ylm) # define sparse for n-correlations tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ] diff --git a/src/utils.jl b/src/utils.jl index e73cb66..e266c10 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -16,8 +16,8 @@ Radial_basis(Rnl::AbstractExplicitLayer) = error("The specification of this Radial_basis should be given explicitly!") end -# more parameters should be added to this function - it is in its current form just for testing -function simple_radial_basis(basis::ScalarPoly4MLBasis; spec = nothing) +# it is in its current form just for the purpose of testing +function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=identity,f_trans::Function=identity; spec = nothing) if isnothing(spec) try spec = natural_indices(basis) @@ -26,14 +26,7 @@ function simple_radial_basis(basis::ScalarPoly4MLBasis; spec = nothing) end end - function f_cut(r) - return r - end - - function f_tran(r) - return r - end - f(r) = f_cut(r) * f_tran(r) + f(r) = f_cut(r) * f_trans(r) return Radial_basis(Chain(trans = WrappedFunction(xx -> [f(norm(x)) for x in xx]), poly = lux(basis), ), spec) @@ -61,59 +54,10 @@ function dropnames(namedtuple::NamedTuple, names::Tuple{Vararg{Symbol}}) end """ -getspec1idx(spec1, bRnl, bYlm) +getspec1idx(spec1, Radial_basis, bYlm) Return a vector of tuples of indices of spec1 w.r.t actual indices (i.e. 1, 2, 3, ...) of bRnl and bYlm """ -function getspec1idx(spec1, bRnl, bYlm) - spec1idx = Vector{Tuple{Int, Int}}(undef, length(spec1)) - spec_Rnl = natural_indices(bRnl); - # TODO: the following line is to be changed to be l-dependent - if typeof(spec_Rnl[1]) <: Int - spec_Rnl = [(n = i, ) for i in spec_Rnl] - is_l = false - elseif !(typeof(spec_Rnl[1]) <: NamedTuple) - is_l = true - @error("Unexpected type of Rnl - Probably means that it is not defined yet") - end - - inv_Rnl = _invmap(spec_Rnl) - - spec_Ylm = natural_indices(bYlm); inv_Ylm = _invmap(spec_Ylm) - - spec1idx = Vector{Tuple{Int, Int}}(undef, length(spec1)) - - if is_l - for (i, b) in enumerate(spec1) - spec1idx[i] = (inv_Rnl[dropnames(b, (:m, ))], inv_Ylm[(l=b.l, m=b.m)]) - end - else - for (i, b) in enumerate(spec1) - spec1idx[i] = (inv_Rnl[dropnames(b, (:m, :l))], inv_Ylm[(l=b.l, m=b.m)]) - end - end - - return spec1idx -end - -function getspec1idx(spec1, bRnl, bYlm, bδs) - spec1idx = Vector{Tuple{Int, Int, Int}}(undef, length(spec1)) - - spec_Rnl = natural_indices(bRnl) - spec_Rnl = [(n = i, ) for i in spec_Rnl] - inv_Rnl = _invmap(spec_Rnl) - - spec_Ylm = natural_indices(bYlm); inv_Ylm = _invmap(spec_Ylm) - - slist = bδs.categories - - spec1idx = Vector{Tuple{Int, Int, Int}}(undef, length(spec1)) - for (i, b) in enumerate(spec1) - spec1idx[i] = (inv_Rnl[dropnames(b, (:m, :l, :s))], inv_Ylm[(l=b.l, m=b.m)], val2i(slist, b.s)) - end - return spec1idx -end - -function getspec1idx_new(spec1, spec_Rnl, bYlm) +function getspec1idx(spec1, spec_Rnl, bYlm) spec1idx = Vector{Tuple{Int, Int}}(undef, length(spec1)) # try is_l = isinteger(spec_Rnl[1].l); catch; is_l = false; end inv_Rnl = _invmap(spec_Rnl) @@ -135,7 +79,7 @@ function getspec1idx_new(spec1, spec_Rnl, bYlm) return spec1idx end -function getspec1idx_new(spec1, spec_Rnl, bYlm, bδs) +function getspec1idx(spec1, spec_Rnl, bYlm, bδs) spec1idx = Vector{Tuple{Int, Int, Int}}(undef, length(spec1)) # try is_l = isinteger(spec_Rnl[1].l); catch; is_l = false; end inv_Rnl = _invmap(spec_Rnl) @@ -302,20 +246,24 @@ end degord2spec(;totaldegree, order, Lmax, radial_basis = legendre_basis, wL = 1, islong = true) Return a list of AA specifications and A specifications """ -function degord2spec(radial::Radial_basis; totaldegree, order, Lmax, catagories = [], wL = 1, islong = true) +function degord2spec(radial::Radial_basis; totaldegree, order, Lmax, catagories = [], wL = 1, islong = true, rSH = false) # Rn = radial.radial_basis(totaldegree) Ylm = CYlmBasis(totaldegree) spec1p = make_nlms_spec(radial, Ylm; totaldegree = totaldegree, admissible = (br, by) -> br.n + wL * by.l <= totaldegree) spec1p = sort(spec1p, by = (x -> x.n + x.l * wL)) - spec1pidx = getspec1idx_new(spec1p, radial.Radialspec, Ylm) + spec1pidx = getspec1idx(spec1p, radial.Radialspec, Ylm) # define sparse for n-correlations tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ] default_admissible = bb -> length(bb) == 0 || sum(b.n for b in bb) + wL * sum(b.l for b in bb) <= totaldegree # to construct SS, SD blocks - filter_ = islong ? RPE_filter_long(Lmax) : RPE_filter(Lmax) + if rSH + filter_ = RPE_filter_real(Lmax) + else + filter_ = islong ? RPE_filter_long(Lmax) : RPE_filter(Lmax) + end specAA = gensparse(; NU = order, tup2b = tup2b, filter = filter_, admissible = default_admissible, diff --git a/test/runtests.jl b/test/runtests.jl index 729f071..6d78f14 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,5 +3,9 @@ using Test @testset "EquivariantModels.jl" begin @testset "CategoricalBasis" begin include("test_categorial.jl") end - @testset "Equivariance" begin include("test_equivariance.jl"); include("test_equiv_with_cate.jl"); end + @testset "Equivariance" begin + include("test_equivariance.jl") + include("test_equiv_with_cate.jl") + include("test_rSH_equivariance.jl") + end end diff --git a/test/test_rSH_equivariance.jl b/test/test_rSH_equivariance.jl new file mode 100644 index 0000000..f03dac1 --- /dev/null +++ b/test/test_rSH_equivariance.jl @@ -0,0 +1,139 @@ +using EquivariantModels, StaticArrays, Test, Polynomials4ML, LinearAlgebra +using ACEbase.Testing: print_tf +using Rotations, WignerD, BlockDiagonals +# using EquivariantModels: Radial_basis +# using Polynomials4ML:lux + +include("wigner.jl") + +@info("Testing the chain that generates a single B basis") +totdeg = 6 +ν = 2 +Lmax = 0 +basis = legendre_basis(totdeg) +radial = EquivariantModels.simple_radial_basis(basis) + +luxchain, ps, st = equivariant_model(totdeg, ν, radial, 0;islong = false, rSH = true) +F(X) = luxchain(X, ps, st)[1] + +for L = 0:Lmax + local F, luxchain, ps, st, F2, luxchain2, ps2, st2 + luxchain, ps, st = equivariant_model(totdeg, ν, radial, L;islong = false, rSH = true) + F(X) = luxchain(X, ps, st)[1] + + luxchain2, ps2, st2 = equivariant_model(EquivariantModels.degord2spec(radial;totaldegree=totdeg,order=ν,Lmax=L,islong = true,rSH = true)[2][1:end-1],radial,L;islong = false,rSH = true) + F2(X) = luxchain(X, ps2, st2)[1] + + @info("Tesing L = $L O(3) equivariance") + for _ = 1:30 + local X, θ1, θ2, θ3, Q, QX + X = [ @SVector(rand(3)) for i in 1:10 ] + θ1 = rand() * 2pi + θ2 = rand() * 2pi + θ3 = rand() * 2pi + Q = RotXYZ(θ1, θ2, θ3) + # Q = rand_rot() + QX = [SVector{3}(x) for x in Ref(Q) .* X] + D = wigner_D(L,Matrix(Q))' + # D = wignerD(L, θ, θ, θ) + + print_tf(@test F(X) ≈ F(QX)) + + end + println() + + @info("Tesing consistency between the two ways of input - in particular the ``closure'' of specifications") + for _ = 1:30 + local X + X = [ @SVector(rand(3)) for i in 1:10 ] + print_tf(@test F(X) ≈ F2(X)) + end + println() + +end + +@info("Testing the chain that generates all B bases") +totdeg = 6 +ν = 2 +L = Lmax +basis = legendre_basis(totdeg) +radial = EquivariantModels.simple_radial_basis(basis) +luxchain, ps, st = equivariant_model(totdeg,ν,radial,L;islong = true) +F(X) = luxchain(X, ps, st)[1] +luxchain2, ps2, st2 = equivariant_model(EquivariantModels.degord2spec(radial;totaldegree=totdeg,order=ν,Lmax=L,islong = true)[2][1:end-1],radial,L;islong = true,rSH = true) +F2(X) = luxchain(X, ps2, st2)[1] + +for ntest = 1:10 + local X, θ1, θ2, θ3, Q, QX + X = [ @SVector(rand(3)) for i in 1:10 ] + θ1 = rand() * 2pi + θ2 = rand() * 2pi + θ3 = rand() * 2pi + Q = RotXYZ(θ1, θ2, θ3) + QX = [SVector{3}(x) for x in Ref(Q) .* X] + + print_tf(@test F(X)[1] ≈ F(QX)[1]) +end +println() + +@info("Consistency check") +totdeg = 6 +ν = 2 +L = Lmax +basis = legendre_basis(totdeg) +radial = EquivariantModels.simple_radial_basis(basis) +luxchain, ps, st = equivariant_model(totdeg,ν,radial,L;islong = true, rSH = true); +F(X) = luxchain(X, ps, st)[1] + +for l = 0:Lmax + @info("Consistency check for L = $l") + local FF, luxchain, ps, st + luxchain, ps, st = equivariant_model(totdeg,ν,radial,l;islong = false, rSH = true) + FF(X) = luxchain(X, ps, st)[1] + + for ntest = 1:20 + X = [ @SVector(rand(3)) for i in 1:10 ] + print_tf(@test F(X)[l+1] == FF(X)) + end + println() +end + +@info("Tesing consistency between the two ways of input - in particular the ``closure'' of specifications") +for _ = 1:10 + local X + X = [ @SVector(rand(3)) for i in 1:10 ] + print_tf(@test length(F(X)) == length(F2(X)) && all([F(X)[i] ≈ F2(X)[i] for i = 1:length(F(X))])) +end +println() + +@info("Tesing the last way of input - given n_list and l_list") + +for L = 0:Lmax + local F, luxchain, ps, st, nn, ll + + nn = rand(0:2,4) + ll = rand(0:2,4) + while iseven(L) != iseven(sum(ll)) + ll = rand(0:2,4) + end + luxchain, ps, st = equivariant_model(nn,ll,radial,L;islong = false, rSH = true) + F(X) = luxchain(X, ps, st)[1] + + @info("Tesing L = $L O(3) equivariance") + for _ = 1:30 + local X, θ, Q, QX + X = [ @SVector(rand(3)) for i in 1:10 ] + θ = rand() * 2pi + Q = RotXYZ(0, 0, θ) + # Q = rand_rot() + QX = [SVector{3}(x) for x in Ref(Q) .* X] + D = wignerD(L, 0, 0, θ) + if length(F(X)) == 0 + continue + end + print_tf(@test F(X) ≈ F(QX)) + + end + println() +end + From 2089fbfac3210dab319a3a9ad32c044fcadda0cb Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Fri, 29 Sep 2023 20:52:52 -0700 Subject: [PATCH 06/18] minor modification --- src/builder.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/builder.jl b/src/builder.jl index 0535923..56a7499 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -100,7 +100,7 @@ function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Ro if !isnothing(idxAA) push!(Irow, idxB) push!(Jcol, idxAA) - if norm(U[irow, icol])<1e-12 + if norm(U[irow, icol] - real.(U[irow, icol]))<1e-12 push!(vals, real.(U[irow, icol])) else push!(vals, U[irow, icol]) From 5851e7ed84ad9d4320a42449f4d5f37d8b93c769 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Sun, 1 Oct 2023 11:26:45 -0700 Subject: [PATCH 07/18] Minor changes --- src/builder.jl | 8 ++++---- src/utils.jl | 14 +++++++++----- test/test_equivariance.jl | 2 +- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/builder.jl b/src/builder.jl index 56a7499..928c79d 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -317,9 +317,9 @@ function equivariant_luxchain_constructor(totdeg, ν, L; wL = 1, Rn = legendre_b Ylm = CYlmBasis(totdeg) - spec1p = make_nlms_spec(Radial_basis(Polynomials4ML.lux(Rn)), Ylm; totaldegree = totdeg, admissible = (br, by) -> br.n + wL * by.l <= totdeg) + spec1p = make_nlms_spec(simple_radial_basis(Rn), Ylm; totaldegree = totdeg, admissible = (br, by) -> br.n + wL * by.l <= totdeg) spec1p = sort(spec1p, by = (x -> x.n + x.l * wL)) - spec1pidx = getspec1idx(spec1p, Radial_basis(Polynomials4ML.lux(Rn)).Radialspec, Ylm) + spec1pidx = getspec1idx(spec1p, simple_radial_basis(Rn).Radialspec, Ylm) # define sparse for n-correlations tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ] @@ -375,9 +375,9 @@ end function equivariant_luxchain_constructor_new(totdeg, ν, L; wL = 1, Rn = legendre_basis(totdeg)) Ylm = CYlmBasis(totdeg) - spec1p = make_nlms_spec(Radial_basis(Polynomials4ML.lux(Rn)), Ylm; totaldegree = totdeg, admissible = (br, by) -> br.n + wL * by.l <= totdeg) + spec1p = make_nlms_spec(simple_radial_basis(Rn), Ylm; totaldegree = totdeg, admissible = (br, by) -> br.n + wL * by.l <= totdeg) spec1p = sort(spec1p, by = (x -> x.n + x.l * wL)) - spec1pidx = getspec1idx(spec1p, Radial_basis(Polynomials4ML.lux(Rn)).Radialspec, Ylm) + spec1pidx = getspec1idx(spec1p, simple_radial_basis(Rn).Radialspec, Ylm) # define sparse for n-correlations tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ] diff --git a/src/utils.jl b/src/utils.jl index e266c10..55dcee6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,19 +1,23 @@ using Polynomials4ML: natural_indices, ScalarPoly4MLBasis, lux -using LuxCore: AbstractExplicitLayer +using LuxCore: AbstractExplicitContainerLayer struct Radial_basis - Rnl::AbstractExplicitLayer + Rnl::AbstractExplicitContainerLayer Radialspec::Vector{NamedTuple} end -Radial_basis(Rnl::AbstractExplicitLayer, spec_Rnl::Union{Vector{Int}, UnitRange{Int64}}) = +Radial_basis(Rnl::AbstractExplicitContainerLayer, spec_Rnl::Union{Vector{Int}, UnitRange{Int64}}) = Radial_basis(Rnl, [(n = i, ) for i in spec_Rnl]) -Radial_basis(Rnl::AbstractExplicitLayer) = +Radial_basis(Rnl::AbstractExplicitContainerLayer) = try Radial_basis(Rnl,natural_indices(Rnl.basis)) catch - error("The specification of this Radial_basis should be given explicitly!") + try + Radial_basis(Rnl,natural_indices(Rnl.layers.poly.basis)) + catch + error("The specification of this Radial_basis should be given explicitly!") + end end # it is in its current form just for the purpose of testing diff --git a/test/test_equivariance.jl b/test/test_equivariance.jl index a07c1a6..71c6cf5 100644 --- a/test/test_equivariance.jl +++ b/test/test_equivariance.jl @@ -218,7 +218,7 @@ L = Lmax luxchain, ps, st = equivariant_luxchain_constructor(totdeg,ν,L) F(X) = luxchain(X, ps, st)[1] -# A small comparison - long vector does give us some redundent basis... +# A small comparison - long vector does give us some redundant basis... @info("Equivariance test") l1l2set = [(l1,l2) for l1 = 0:L for l2 = 0:L-l1] From 349201a435d713478db3d8837a158cbd7a5e86a1 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Sun, 1 Oct 2023 11:36:45 -0700 Subject: [PATCH 08/18] Moving those Radial_basis relevant code to a new .jl file --- src/EquivariantModels.jl | 1 + src/radial.jl | 37 +++++++++++++++++++++++++++++++++++++ src/utils.jl | 40 ++-------------------------------------- 3 files changed, 40 insertions(+), 38 deletions(-) create mode 100644 src/radial.jl diff --git a/src/EquivariantModels.jl b/src/EquivariantModels.jl index d714d67..1accbb9 100644 --- a/src/EquivariantModels.jl +++ b/src/EquivariantModels.jl @@ -1,5 +1,6 @@ module EquivariantModels +include("radial.jl") include("utils.jl") include("lux_utils.jl") include("categorical.jl") diff --git a/src/radial.jl b/src/radial.jl new file mode 100644 index 0000000..4b59230 --- /dev/null +++ b/src/radial.jl @@ -0,0 +1,37 @@ +using Polynomials4ML: natural_indices, ScalarPoly4MLBasis, lux +using LuxCore: AbstractExplicitContainerLayer + +struct Radial_basis + Rnl::AbstractExplicitContainerLayer + Radialspec::Vector{NamedTuple} +end + +Radial_basis(Rnl::AbstractExplicitContainerLayer, spec_Rnl::Union{Vector{Int}, UnitRange{Int64}}) = + Radial_basis(Rnl, [(n = i, ) for i in spec_Rnl]) + +Radial_basis(Rnl::AbstractExplicitContainerLayer) = + try + Radial_basis(Rnl,natural_indices(Rnl.basis)) + catch + try + Radial_basis(Rnl,natural_indices(Rnl.layers.poly.basis)) + catch + error("The specification of this Radial_basis should be given explicitly!") + end + end + +# it is in its current form just for the purpose of testing - a more specific example can be found in forces.jl +function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=identity,f_trans::Function=identity; spec = nothing) + if isnothing(spec) + try + spec = natural_indices(basis) + catch + error("The specification of this Radial_basis should be given explicitly!") + end + end + + f(r) = f_cut(r) * f_trans(r) + + return Radial_basis(Chain(trans = WrappedFunction(xx -> [f(norm(x)) for x in xx]), + poly = lux(basis), ), spec) +end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index 55dcee6..6928683 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,40 +1,4 @@ -using Polynomials4ML: natural_indices, ScalarPoly4MLBasis, lux -using LuxCore: AbstractExplicitContainerLayer - -struct Radial_basis - Rnl::AbstractExplicitContainerLayer - Radialspec::Vector{NamedTuple} -end - -Radial_basis(Rnl::AbstractExplicitContainerLayer, spec_Rnl::Union{Vector{Int}, UnitRange{Int64}}) = - Radial_basis(Rnl, [(n = i, ) for i in spec_Rnl]) - -Radial_basis(Rnl::AbstractExplicitContainerLayer) = - try - Radial_basis(Rnl,natural_indices(Rnl.basis)) - catch - try - Radial_basis(Rnl,natural_indices(Rnl.layers.poly.basis)) - catch - error("The specification of this Radial_basis should be given explicitly!") - end - end - -# it is in its current form just for the purpose of testing -function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=identity,f_trans::Function=identity; spec = nothing) - if isnothing(spec) - try - spec = natural_indices(basis) - catch - error("The specification of this Radial_basis should be given explicitly!") - end - end - - f(r) = f_cut(r) * f_trans(r) - - return Radial_basis(Chain(trans = WrappedFunction(xx -> [f(norm(x)) for x in xx]), - poly = lux(basis), ), spec) -end +using Polynomials4ML: natural_indices """ _invmap(a::AbstractVector) @@ -58,7 +22,7 @@ function dropnames(namedtuple::NamedTuple, names::Tuple{Vararg{Symbol}}) end """ -getspec1idx(spec1, Radial_basis, bYlm) +getspec1idx(spec1, spec_Rnl, bYlm) Return a vector of tuples of indices of spec1 w.r.t actual indices (i.e. 1, 2, 3, ...) of bRnl and bYlm """ function getspec1idx(spec1, spec_Rnl, bYlm) From 0aafe09216f5ab166ed722d6a77c0b28c3196c7e Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Sun, 1 Oct 2023 16:03:44 -0700 Subject: [PATCH 09/18] adapting test_potential.jl to the latest version --- examples/potential/test_potential.jl | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/examples/potential/test_potential.jl b/examples/potential/test_potential.jl index 53926af..3ab008e 100644 --- a/examples/potential/test_potential.jl +++ b/examples/potential/test_potential.jl @@ -27,10 +27,16 @@ rng = Random.MersenneTwister() rcut = 5.5 maxL = 0 -L = 0 -Aspec, AAspec = degord2spec(; totaldegree = 6, - order = 3, - Lmax = 0, ) +totdeg = 6 +ord = 3 + +fcut(rcut::Float64,pin::Int=2,pout::Int=2) = r -> (r < rcut ? abs( (r/rcut)^pin - 1)^pout : 0) +ftrans(r0::Float64=.0,p::Int=2) = r -> ( (1+r0)/(1+r) )^p +radial = simple_radial_basis(legendre_basis(totdeg),fcut(rcut),ftrans()) + +Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, + order = ord, + Lmax = maxL, ) cats = AtomicNumber.([:W, :Cu, :Ni, :Fe, :Al]) ipairs = collect(Combinatorics.permutations(1:length(cats), 2)) allcats = collect(SVector{2}.(Combinatorics.permutations(cats, 2))) @@ -52,7 +58,7 @@ for bb in ori_AAspec push!(new_AAspec, newbb) end -luxchain, ps, st = equivariant_model(new_AAspec, L; categories=allcats, islong = false) +luxchain, ps, st = equivariant_model(new_AAspec, radial, L; categories=allcats, islong = false) at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) iCu = [5, 12]; iNi = [3, 8]; iAl = [10]; iFe = [6]; @@ -67,7 +73,7 @@ get_Z0S(zz0, ZZS) = [SVector{2}(zz0, zzs) for zzs in ZZS] Z0S = get_Z0S(z0, Zs) # input of luxmodel -X = (Rs, Z0S) +X = [Rs, Z0S] out, st = luxchain(X, ps, st) @@ -85,8 +91,8 @@ model(X, ps, st) g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1] -F(Rs) = model((Rs, Z0S), ps, st)[1] -dF(Rs) = Zygote.gradient(rs -> model((rs, Z0S), ps, st)[1], Rs)[1] +F(Rs) = model([Rs, Z0S], ps, st)[1] +dF(Rs) = Zygote.gradient(rs -> model([rs, Z0S], ps, st)[1], Rs)[1] ## @info("test derivative w.r.t X") From e57576c8123d3ce1b1513996f3f42ae35dd71d9c Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Sun, 1 Oct 2023 16:08:41 -0700 Subject: [PATCH 10/18] Minor changes --- examples/potential/test_potential.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/potential/test_potential.jl b/examples/potential/test_potential.jl index 3ab008e..9a38c4e 100644 --- a/examples/potential/test_potential.jl +++ b/examples/potential/test_potential.jl @@ -17,7 +17,7 @@ function grad_test2(f, df, X::AbstractVector; verbose = true) verbose && @printf("---------|----------- \n") verbose && @printf(" h | error \n") verbose && @printf("---------|----------- \n") - for h in 0.1.^(-3:9) + for h in 0.1.^(0:12) gh = [ (f(X + h * EE[:, i]) - F) / h for i = 1:nX ] verbose && @printf(" %.1e | %.2e \n", h, norm(gh - ∇F, Inf)) end @@ -97,6 +97,7 @@ dF(Rs) = Zygote.gradient(rs -> model([rs, Z0S], ps, st)[1], Rs)[1] ## @info("test derivative w.r.t X") print_tf(@test fdtest(F, dF, Rs; verbose=true)) +println() @info("test derivative w.r.t parameter") From d24abeaffdda6114363aa1a5a56f70891be336b6 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Mon, 2 Oct 2023 10:09:37 -0700 Subject: [PATCH 11/18] Minor modifications --- src/radial.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/radial.jl b/src/radial.jl index 4b59230..e58d945 100644 --- a/src/radial.jl +++ b/src/radial.jl @@ -1,15 +1,16 @@ using Polynomials4ML: natural_indices, ScalarPoly4MLBasis, lux -using LuxCore: AbstractExplicitContainerLayer +using LuxCore: AbstractExplicitContainerLayer, AbstractExplicitLayer -struct Radial_basis - Rnl::AbstractExplicitContainerLayer - Radialspec::Vector{NamedTuple} -end +struct Radial_basis <:AbstractExplicitContainerLayer{(:Rnl,)} + Rnl::AbstractExplicitLayer + # make it meta or just leave it as a NameTuple ? + Radialspec::Vector{NamedTuple} + end -Radial_basis(Rnl::AbstractExplicitContainerLayer, spec_Rnl::Union{Vector{Int}, UnitRange{Int64}}) = +Radial_basis(Rnl::AbstractExplicitLayer, spec_Rnl::Union{Vector{Int}, UnitRange{Int64}}) = Radial_basis(Rnl, [(n = i, ) for i in spec_Rnl]) -Radial_basis(Rnl::AbstractExplicitContainerLayer) = +Radial_basis(Rnl::AbstractExplicitLayer) = try Radial_basis(Rnl,natural_indices(Rnl.basis)) catch @@ -21,7 +22,7 @@ Radial_basis(Rnl::AbstractExplicitContainerLayer) = end # it is in its current form just for the purpose of testing - a more specific example can be found in forces.jl -function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=identity,f_trans::Function=identity; spec = nothing) +function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_trans::Function=r->1; spec = nothing) if isnothing(spec) try spec = natural_indices(basis) @@ -30,8 +31,7 @@ function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=identity, end end - f(r) = f_cut(r) * f_trans(r) - - return Radial_basis(Chain(trans = WrappedFunction(xx -> [f(norm(x)) for x in xx]), - poly = lux(basis), ), spec) + return Radial_basis(Chain(trans = WrappedFunction(xx -> [f_trans(norm(x)) for x in xx]), + cutoff = WrappedFunction(xx -> f_cut.(xx)), + poly = lux(basis), ), spec) end \ No newline at end of file From d509a48af5ae6429e0803d0d11515d78a713683f Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Mon, 2 Oct 2023 10:51:40 -0700 Subject: [PATCH 12/18] Typo fix --- src/radial.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/radial.jl b/src/radial.jl index e58d945..606e653 100644 --- a/src/radial.jl +++ b/src/radial.jl @@ -31,7 +31,6 @@ function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_tr end end - return Radial_basis(Chain(trans = WrappedFunction(xx -> [f_trans(norm(x)) for x in xx]), - cutoff = WrappedFunction(xx -> f_cut.(xx)), + return Radial_basis(Chain(trans = WrappedFunction(xx -> [f_trans(norm(x)) * f_cut(norm(x)) for x in xx]), poly = lux(basis), ), spec) end \ No newline at end of file From 9321a06b95b833b61f9c44d74bd374d003ae25f9 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Mon, 2 Oct 2023 12:51:31 -0700 Subject: [PATCH 13/18] Minor revision to adapt reverse over reverse check --- examples/potential/test_potential.jl | 64 ++++++++++++++++++++++++++-- src/radial.jl | 11 ++--- 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/examples/potential/test_potential.jl b/examples/potential/test_potential.jl index 9a38c4e..a752ccc 100644 --- a/examples/potential/test_potential.jl +++ b/examples/potential/test_potential.jl @@ -1,11 +1,13 @@ -using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote +using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote, Polynomials4ML using Polynomials4ML: LinearLayer, RYlmBasis, lux -using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA +using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA, simple_radial_basis using JuLIP, Combinatorics, Test using ACEbase.Testing: println_slim, print_tf, fdtest using Optimisers: destructure using Printf +L = 0 + include("staticprod.jl") function grad_test2(f, df, X::AbstractVector; verbose = true) @@ -89,7 +91,7 @@ model(X, ps, st) # testing derivative (forces) g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1] - +grad_model(X, ps, st) = Zygote.gradient(_X -> model(_X, ps, st)[1], X)[1] F(Rs) = model([Rs, Z0S], ps, st)[1] dF(Rs) = Zygote.gradient(rs -> model([rs, Z0S], ps, st)[1], Rs)[1] @@ -110,3 +112,59 @@ dFp = w -> ( gl = Zygote.gradient(p -> model(X, p, st)[1], ps)[1]; destructure(g grad_test2(Fp, dFp, W0) +# === define toy loss === +function loss(X, p) + ps = _rest(p) + g = grad_model(X, ps, st) + return sum(norm.(g)) +end + +p_vec, _rest = destructure(ps) + +# === override useful functions === +using Polynomials4ML +import ChainRulesCore: ProjectTo +using ChainRulesCore +using SparseArrays +function Polynomials4ML._pullback_evaluate(∂A, basis::Polynomials4ML.PooledSparseProduct{NB}, BB::Polynomials4ML.TupMat) where {NB} + nX = size(BB[1], 1) + TA = promote_type(eltype.(BB)..., eltype(∂A)) + # @show TA + ∂BB = ntuple(i -> zeros(TA, size(BB[i])...), NB) + Polynomials4ML._pullback_evaluate!(∂BB, ∂A, basis, BB) + return ∂BB +end + +function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray) + dy = if axes(dx) == project.axes + dx + else + if size(dx) != (length(project.axes[1]), length(project.axes[2])) + throw(_projection_mismatch(project.axes, size(dx))) + end + reshape(dx, project.axes) + end + T = promote_type(ChainRulesCore.project_type(project.element), eltype(dx)) + nzval = Vector{T}(undef, length(project.rowval)) + k = 0 + for col in project.axes[2] + for i in project.nzranges[col] + row = project.rowval[i] + val = dy[row, col] + nzval[k += 1] = project.element(val) + end + end + m, n = map(length, project.axes) + return SparseMatrixCSC(m, n, project.colptr, project.rowval, nzval) +end + +# === reverse over reverse === +using ReverseDiff +gg1 = ReverseDiff.gradient(_p -> loss(X, _p), p_vec) + +using ACEbase +ACEbase.Testing.fdtest( + _p -> loss(X, _p), + _p -> ReverseDiff.gradient(__p -> loss1(X, __p), _p), + p_vec ) +## \ No newline at end of file diff --git a/src/radial.jl b/src/radial.jl index 606e653..ecee4db 100644 --- a/src/radial.jl +++ b/src/radial.jl @@ -1,10 +1,10 @@ using Polynomials4ML: natural_indices, ScalarPoly4MLBasis, lux using LuxCore: AbstractExplicitContainerLayer, AbstractExplicitLayer -struct Radial_basis <:AbstractExplicitContainerLayer{(:Rnl,)} - Rnl::AbstractExplicitLayer + struct Radial_basis{T <: AbstractExplicitLayer} <:AbstractExplicitContainerLayer{(:Rnl, )} + Rnl::T # make it meta or just leave it as a NameTuple ? - Radialspec::Vector{NamedTuple} + Radialspec::Vector #{NamedTuple} #TODO: double check this... end Radial_basis(Rnl::AbstractExplicitLayer, spec_Rnl::Union{Vector{Int}, UnitRange{Int64}}) = @@ -31,6 +31,7 @@ function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_tr end end - return Radial_basis(Chain(trans = WrappedFunction(xx -> [f_trans(norm(x)) * f_cut(norm(x)) for x in xx]), - poly = lux(basis), ), spec) + f(r) = f_trans(r) * f_cut(r) + + return Radial_basis(Chain(getnorm = WrappedFunction(x -> norm.(x)), trans = WrappedFunction(x -> f.(x)), poly = lux(basis), ), spec) end \ No newline at end of file From 8e0bdcc6e97109acf01064debcc293c355886640 Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Tue, 3 Oct 2023 22:02:53 -0700 Subject: [PATCH 14/18] bug fix for doublepb in SitePot --- examples/potential/forces_chho.jl | 256 ++++++++++++++++++++++++++++++ 1 file changed, 256 insertions(+) create mode 100644 examples/potential/forces_chho.jl diff --git a/examples/potential/forces_chho.jl b/examples/potential/forces_chho.jl new file mode 100644 index 0000000..ad23490 --- /dev/null +++ b/examples/potential/forces_chho.jl @@ -0,0 +1,256 @@ +using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote, Polynomials4ML +using Polynomials4ML: LinearLayer, RYlmBasis, lux +using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA, simple_radial_basis +rng = Random.MersenneTwister() + +## + +rcut = 5.5 +maxL = 0 +totdeg = 6 +ord = 3 + +fcut(rcut::Float64,pin::Int=2,pout::Int=2) = r -> (r < rcut ? abs( (r/rcut)^pin - 1)^pout : 0) +ftrans(r0::Float64=.0,p::Int=2) = r -> ( (1+r0)/(1+r) )^p +radial = simple_radial_basis(legendre_basis(totdeg),fcut(rcut),ftrans()) + +Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, + order = ord, + Lmax = maxL, ) + +l_basis, ps_basis, st_basis = equivariant_model(AAspec, radial, maxL; islong = false) +X = [ @SVector(randn(3)) for i in 1:10 ] +B = l_basis(X, ps_basis, st_basis)[1] + +# now extend the above BB basis to a model +len_BB = length(B) + +model = append_layer(l_basis, WrappedFunction(t -> real(t)); l_name=:real) +model = append_layer(model, LinearLayer(len_BB, 1); l_name=:dot) +model = append_layer(model, WrappedFunction(t -> t[1]); l_name=:get1) + +ps, st = Lux.setup(rng, model) +out, st = model(X, ps, st) + +# testing derivative (forces) +g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1] + +## + +module Pot + import JuLIP, Zygote, StaticArrays + import JuLIP: cutoff, Atoms + import ACEbase: evaluate!, evaluate_d! + import StaticArrays: SVector, SMatrix + import ReverseDiff + import ChainRulesCore + import ChainRulesCore: rrule, ignore_derivatives + + import Optimisers: destructure + + struct LuxCalc <: JuLIP.SitePotential + luxmodel + ps + st + rcut::Float64 + restructure + end + + function LuxCalc(luxmodel, ps, st, rcut) + pvec, rest = destructure(ps) + return LuxCalc(luxmodel, ps, st, rcut, rest) + end + + cutoff(calc::LuxCalc) = calc.rcut + + function evaluate!(tmp, calc::LuxCalc, Rs, Zs, z0) + E, st = calc.luxmodel(Rs, calc.ps, calc.st) + return E[1] + end + + function evaluate_d!(dEs, tmpd, calc::LuxCalc, Rs, Zs, z0) + g = Zygote.gradient(X -> calc.luxmodel(X, calc.ps, calc.st)[1], Rs)[1] + @assert length(g) == length(Rs) <= length(dEs) + dEs[1:length(g)] .= g + return dEs + end + + # ----- parameter estimation stuff + + + function lux_energy(at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple) + nlist = ignore_derivatives() do + JuLIP.neighbourlist(at, calc.rcut) + end + return sum( i -> begin + Js, Rs, Zs = ignore_derivatives() do + JuLIP.Potentials.neigsz(nlist, at, i) + end + Ei, st = calc.luxmodel(Rs, ps, st) + Ei[1] + end, + 1:length(at) + ) + end + + + function lux_efv(at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple) + nlist = ignore_derivatives() do + JuLIP.neighbourlist(at, calc.rcut) + end + E = 0.0 + F = zeros(SVector{3, Float64}, length(at)) + V = zero(SMatrix{3, 3, Float64}) + for i = 1:length(at) + Js, Rs, Zs = ignore_derivatives() do + JuLIP.Potentials.neigsz(nlist, at, i) + end + comp = Zygote.withgradient(_X -> calc.luxmodel(_X, ps, st)[1], Rs) + Ei = comp.val + _∇Ei = comp.grad[1] + ∇Ei = ReverseDiff.value.(_∇Ei) + # energy + E += Ei + + # Forces + for j = 1:length(Rs) + F[Js[j]] -= ∇Ei[j] + F[i] += ∇Ei[j] + end + + # Virial + if length(Rs) > 0 + V -= sum(∇Eij * Rij' for (∇Eij, Rij) in zip(∇Ei, Rs)) + end + end + + return E, F, V + end + +# site_virial(dV::AbstractVector{JVec{T1}}, R::AbstractVector{JVec{T2}} +# ) where {T1, T2} = ( +# length(R) > 0 ? (- sum( dVi * Ri' for (dVi, Ri) in zip(dV, R) )) +# : zero(JMat{fltype_intersect(T1, T2)}) +# ) + # function rrule(::typeof(lux_energy), at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple) + # E = lux_energy(at, calc, ps, st) + # function pb(Δ) + # nlist = JuLIP.neighbourlist(at, calc.rcut) + # @show Δ + # error("stop") + # function pb_inner(i) + # Js, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, i) + # gi = ReverseDiff.gradient() + # end + # for i = 1:length(at) + # Ei, st = calc.luxmodel(Rs, calc.ps, calc.st) + # E += Ei[1] + # end + # end + # end + +end + +## + +using JuLIP +JuLIP.usethreads!(false) +ps.dot.W[:] .= 1e-2 * randn(length(ps.dot.W)) + +at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) +calc = Pot.LuxCalc(model, ps, st, rcut) +JuLIP.energy(calc, at) +JuLIP.forces(calc, at) +JuLIP.virial(calc, at) +Pot.lux_energy(at, calc, ps, st) + +@time JuLIP.energy(calc, at) +@time Pot.lux_energy(at, calc, ps, st) +@time JuLIP.forces(calc, at) + +## + +using Optimisers, ReverseDiff + +p_vec, _rest = destructure(ps) +f(_pvec) = Pot.lux_energy(at, calc, _rest(_pvec), st) + +f(p_vec) +gz = Zygote.gradient(f, p_vec)[1] + +@time f(p_vec) +@time Zygote.gradient(f, p_vec)[1] + +# We can use either Zygote or ReverseDiff for gradients. +gr = ReverseDiff.gradient(f, p_vec) +@show gr ≈ gz + +@info("Interestingly ReverseDiff is much faster here, almost optimal") +@time f(p_vec) +@time Zygote.gradient(f, p_vec)[1] +@time ReverseDiff.gradient(f, p_vec) + +## + +@info("Compute Energies, Forces and Virials at the same time") +E, F, V = Pot.lux_efv(at, calc, ps, st) +@show E ≈ JuLIP.energy(calc, at) +@show F ≈ JuLIP.forces(calc, at) +@show V ≈ JuLIP.virial(calc, at) + +## + +# make up a baby loss function type thing. +function loss(at, calc, p_vec) + ps = _rest(p_vec) + st = calc.st + E, F, V = Pot.lux_efv(at, calc, ps, st) + Nat = length(at) + return (E / Nat)^2 + + sum( f -> sum(abs2, f), F ) / Nat + + sum(abs2, V) +end + +loss(at, calc, p_vec) + +# ==== +using Polynomials4ML +import ChainRulesCore: ProjectTo +using ChainRulesCore +using SparseArrays +function Polynomials4ML._pullback_evaluate(∂A, basis::Polynomials4ML.PooledSparseProduct{NB}, BB::Polynomials4ML.TupMat) where {NB} + nX = size(BB[1], 1) + TA = promote_type(eltype.(BB)..., eltype(∂A)) + # @show TA + ∂BB = ntuple(i -> zeros(TA, size(BB[i])...), NB) + Polynomials4ML._pullback_evaluate!(∂BB, ∂A, basis, BB) + return ∂BB +end + +function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray) + dy = if axes(dx) == project.axes + dx + else + if size(dx) != (length(project.axes[1]), length(project.axes[2])) + throw(_projection_mismatch(project.axes, size(dx))) + end + reshape(dx, project.axes) + end + T = promote_type(ChainRulesCore.project_type(project.element), eltype(dx)) + nzval = Vector{T}(undef, length(project.rowval)) + k = 0 + for col in project.axes[2] + for i in project.nzranges[col] + row = project.rowval[i] + val = dy[row, col] + nzval[k += 1] = project.element(val) + end + end + m, n = map(length, project.axes) + return SparseMatrixCSC(m, n, project.colptr, project.rowval, nzval) +end + + + +## +ReverseDiff.gradient(p -> loss(at, calc, p), p_vec) From 25f487dd4cb12fb173298dd6862de9d1d7b82b58 Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Tue, 3 Oct 2023 23:13:43 -0700 Subject: [PATCH 15/18] add multi species pot - but returning nan --- .../potential/test_potential_multi_chho.jl | 174 ++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 examples/potential/test_potential_multi_chho.jl diff --git a/examples/potential/test_potential_multi_chho.jl b/examples/potential/test_potential_multi_chho.jl new file mode 100644 index 0000000..2bbff30 --- /dev/null +++ b/examples/potential/test_potential_multi_chho.jl @@ -0,0 +1,174 @@ +using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote, Polynomials4ML +using Polynomials4ML: LinearLayer, RYlmBasis, lux +using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA, simple_radial_basis +using JuLIP, Combinatorics, Test +using ACEbase.Testing: println_slim, print_tf, fdtest +using Optimisers: destructure +using Printf + +L = 0 + +include("staticprod.jl") + +function grad_test2(f, df, X::AbstractVector; verbose = true) + F = f(X) + ∇F = df(X) + nX = length(X) + EE = Matrix(I, (nX, nX)) + + verbose && @printf("---------|----------- \n") + verbose && @printf(" h | error \n") + verbose && @printf("---------|----------- \n") + for h in 0.1.^(0:12) + gh = [ (f(X + h * EE[:, i]) - F) / h for i = 1:nX ] + verbose && @printf(" %.1e | %.2e \n", h, norm(gh - ∇F, Inf)) + end +end + +rng = Random.MersenneTwister() + +rcut = 5.5 +maxL = 0 +totdeg = 6 +ord = 3 + +fcut(rcut::Float64,pin::Int=2,pout::Int=2) = r -> (r < rcut ? abs( (r/rcut)^pin - 1)^pout : 0) +ftrans(r0::Float64=.0,p::Int=2) = r -> ( (1+r0)/(1+r) )^p +radial = simple_radial_basis(legendre_basis(totdeg),fcut(rcut),ftrans()) + +Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, + order = ord, + Lmax = maxL, ) +cats = AtomicNumber.([:W, :Cu, :Ni, :Fe, :Al]) + +ipairs = collect(Combinatorics.permutations(1:length(cats), 2)) +allcats = collect(SVector{2}.(Combinatorics.permutations(cats, 2))) + +for (i, cat) in enumerate(cats) + push!(ipairs, [i, i]) + push!(allcats, SVector{2}([cat, cat])) +end + +new_spec = [] +ori_AAspec = deepcopy(AAspec) +new_AAspec = [] + +for bb in ori_AAspec + newbb = [] + for (t, ip) in zip(bb, ipairs) + push!(newbb, (t..., s = cats[ip])) + end + push!(new_AAspec, newbb) +end + +luxchain, ps, st = equivariant_model(new_AAspec, radial, L; categories=allcats, islong = false) + +at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) +iCu = [5, 12]; iNi = [3, 8]; iAl = [10]; iFe = [6]; +at.Z[iCu] .= cats[2]; at.Z[iNi] .= cats[3]; at.Z[iAl] .= cats[4]; at.Z[iFe] .= cats[5]; +nlist = JuLIP.neighbourlist(at, rcut) +_, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, 1) +# centere atom +z0 = at.Z[1] + +# serialization, I want the input data structure to lux as simple as possible +get_Z0S(zz0, ZZS) = [SVector{2}(zz0, zzs) for zzs in ZZS] +Z0S = get_Z0S(z0, Zs) + +# input of luxmodel +X = [Rs, Z0S] + +out, st = luxchain(X, ps, st) + + +# == lux chain eval and grad +B = out + +model = append_layers(luxchain, get1 = WrappedFunction(t -> real.(t)), dot = LinearLayer(length(B), 1), get2 = WrappedFunction(t -> t[1])) + +ps, st = Lux.setup(MersenneTwister(1234), model) +ps.dot.W[:] = ps.dot.W[:] / 1000 + +model(X, ps, st) + +# testing derivative (forces) +g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1][1] +grad_model(X, ps, st) = Zygote.gradient(_X -> model(_X, ps, st)[1], X)[1] + +F(Rs) = model([Rs, Z0S], ps, st)[1] +dF(Rs) = Zygote.gradient(rs -> model([rs, Z0S], ps, st)[1], Rs)[1] + +## +@info("test derivative w.r.t X") +print_tf(@test fdtest(F, dF, Rs; verbose=true)) +println() + + +@info("test derivative w.r.t parameter") +p = Zygote.gradient(p -> model(X, p, st)[1], ps)[1] +p, = destructure(p) + +W0, re = destructure(ps) +Fp = w -> model(X, re(w), st)[1] +dFp = w -> ( gl = Zygote.gradient(p -> model(X, p, st)[1], ps)[1]; destructure(gl)[1]) +grad_test2(Fp, dFp, W0) + + +# === define toy loss === +function loss(X, p) + ps = _rest(p) + g = grad_model(X, ps, st)[1] + return sum(norm.(g)) +end + +p_vec, _rest = destructure(ps) + +# === override useful functions === +using Polynomials4ML +import ChainRulesCore: ProjectTo +using ChainRulesCore +using SparseArrays +function Polynomials4ML._pullback_evaluate(∂A, basis::Polynomials4ML.PooledSparseProduct{NB}, BB::Polynomials4ML.TupMat) where {NB} + nX = size(BB[1], 1) + TA = promote_type(eltype.(BB)..., eltype(∂A)) + # @show TA + ∂BB = ntuple(i -> zeros(TA, size(BB[i])...), NB) + Polynomials4ML._pullback_evaluate!(∂BB, ∂A, basis, BB) + return ∂BB +end + +function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray) + dy = if axes(dx) == project.axes + dx + else + if size(dx) != (length(project.axes[1]), length(project.axes[2])) + throw(_projection_mismatch(project.axes, size(dx))) + end + reshape(dx, project.axes) + end + T = promote_type(ChainRulesCore.project_type(project.element), eltype(dx)) + nzval = Vector{T}(undef, length(project.rowval)) + k = 0 + for col in project.axes[2] + for i in project.nzranges[col] + row = project.rowval[i] + val = dy[row, col] + nzval[k += 1] = project.element(val) + end + end + m, n = map(length, project.axes) + return SparseMatrixCSC(m, n, project.colptr, project.rowval, nzval) +end + +# === reverse over reverse === +using ReverseDiff +g1 = ReverseDiff.gradient(_p -> loss(X, _p), p_vec) + +Zygote.gradient(_p -> loss(X, _p), p_vec) + +using ACEbase +ACEbase.Testing.fdtest( + _p -> loss(X, _p), + _p -> Zygote.gradient(__p -> loss(X, __p), _p), + p_vec ) +## \ No newline at end of file From f0ba137d85f46dc50042af7674a701ac5471a768 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Wed, 4 Oct 2023 19:54:47 -0700 Subject: [PATCH 16/18] add an assertion to avoid ambiguity in radial_spec & spec1p --- src/builder.jl | 3 +++ src/utils.jl | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/builder.jl b/src/builder.jl index 928c79d..19fbd81 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -135,6 +135,9 @@ categories : A list of categories function xx2AA(spec_nlm, radial::Radial_basis; categories=[], d=3, rSH = false) # Configuration to AA bases - this is what all chains have in common # from spec_nlm to all possible spec1p spec1p, lmax, nmax = specnlm2spec1p(spec_nlm) + # An assertation whether all the radial specs are in spec1p + @assert issubset(radial.Radialspec, nset(spec1p)) || issubset(radial.Radialspec, nlset(spec1p)) + dict_spec1p = Dict([spec1p[i] => i for i = 1:length(spec1p)]) Ylm = rSH ? RYlmBasis(lmax) : CYlmBasis(lmax) # Rn = radial_basis(nmax) diff --git a/src/utils.jl b/src/utils.jl index 6928683..d319690 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -160,6 +160,9 @@ function specnlm2spec1p(spec_nlm) return spec1p, lmax, nmax + 1 end +nset(spec1p) = [ (n=spec.n,) for spec in spec1p] +nlset(spec1p) = [ (n=spec.n, l=spec.l,) for spec in spec1p] + """ closure(spec_nlm,filter) Make a spec_nlm to be a "complete" set to be symmetrised w.r.t to the filter From ede25451d599311b8c3d35d913aaad26814b5b9a Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Wed, 4 Oct 2023 19:59:45 -0700 Subject: [PATCH 17/18] Typo fix --- src/builder.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/builder.jl b/src/builder.jl index 19fbd81..938fee7 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -136,8 +136,8 @@ function xx2AA(spec_nlm, radial::Radial_basis; categories=[], d=3, rSH = false) # from spec_nlm to all possible spec1p spec1p, lmax, nmax = specnlm2spec1p(spec_nlm) # An assertation whether all the radial specs are in spec1p - @assert issubset(radial.Radialspec, nset(spec1p)) || issubset(radial.Radialspec, nlset(spec1p)) - + @assert issubset(nset(spec1p), radial.Radialspec) || issubset(nlset(spec1p), radial.Radialspec) + dict_spec1p = Dict([spec1p[i] => i for i = 1:length(spec1p)]) Ylm = rSH ? RYlmBasis(lmax) : CYlmBasis(lmax) # Rn = radial_basis(nmax) From 085263c9485c12c967c2b68fd095170f9f05bbfa Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Thu, 5 Oct 2023 10:43:56 -0700 Subject: [PATCH 18/18] modify RPE_filter --- src/builder.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/builder.jl b/src/builder.jl index 938fee7..1ecc941 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -13,10 +13,10 @@ export equivariant_model, equivariant_SYY_model, equivariant_luxchain_constructo P4ML = Polynomials4ML -RPE_filter(L) = bb -> (length(bb) == 0) || ((abs(sum(b.m for b in bb)) <= L) && iseven(sum(b.l for b in bb)+L)) +RPE_filter(L) = bb -> (length(bb) == 0) || ((abs(sum(b.m for b in bb)) <= L) && iseven(sum(b.l for b in bb)+L)) && ( length(bb) == 1 && L == 0 ? bb[1].l == 0 : true ) RPE_filter_long(L) = bb -> (length(bb) == 0) || (abs(sum(b.m for b in bb)) <= L) -RPE_filter_real(L) = bb -> (length(bb) == 0) || mm_filter([b.m for b in bb],L) && iseven(sum(b.l for b in bb)+L) +RPE_filter_real(L) = bb -> (length(bb) == 0) || mm_filter([b.m for b in bb],L) && iseven(sum(b.l for b in bb)+L) && ( length(bb) == 1 && L == 0 ? bb[1].l == 0 : true ) """ _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T},Rot3DCoeffs_real{L,T},Rot3DCoeffs_long{L,T}},