From 1d58770c216c06d4e297630985dfffa00ba775b3 Mon Sep 17 00:00:00 2001 From: jerryho Date: Sun, 24 Sep 2023 01:02:28 -0400 Subject: [PATCH 01/14] fixing type ambguitity of CategorialBasis --- src/categorical.jl | 21 +++++++++++++++------ test/test_categorial.jl | 3 +++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/categorical.jl b/src/categorical.jl index 3cd65ed..d7b5c82 100644 --- a/src/categorical.jl +++ b/src/categorical.jl @@ -1,6 +1,8 @@ using Polynomials4ML -using Polynomials4ML: AbstractPoly4MLBasis -using StaticArrays: SVector +using Polynomials4ML: AbstractPoly4MLBasis, SphericalCoords +using StaticArrays: SVector, StaticArray + +import Polynomials4ML: _valtype, _out_size, _outsym export CategoricalBasis @@ -91,10 +93,17 @@ CategoricalBasis(categories::AbstractArray, meta = Dict{String, Any}() ) = #Polynomials4ML._outsym(x::Char) = :char -Polynomials4ML._outsym(x::T) where T = :T -Polynomials4ML._out_size(basis::CategoricalBasis{LEN, T}, x::T) where {LEN, T} = (LEN,) -Polynomials4ML._out_size(basis::CategoricalBasis{LEN, T}, x::Vector{T}) where {LEN, T} = (length(x),LEN) -Polynomials4ML._valtype(basis::CategoricalBasis{LEN, T}, x::Union{T,Vector{T}}) where {LEN, T} = Bool +_outsym(x::T) where T = :T + +const NSS = Union{Number, SphericalCoords, StaticArray} + +_out_size(basis::CategoricalBasis{LEN, T}, x::T) where {LEN, T} = (LEN,) +_out_size(basis::CategoricalBasis{LEN, T}, x::Vector{T}) where {LEN, T} = (length(x), LEN) +_out_size(basis::CategoricalBasis{LEN, T}, x::NSS) where {LEN, T <: NSS} = (LEN, ) + +_valtype(basis::CategoricalBasis{LEN, T}, x::Union{T,Vector{T}}) where {LEN, T} = Bool +_valtype(basis::CategoricalBasis{LEN, T}, x::NSS) where {LEN, T <: NSS} = Bool +_valtype(basis::CategoricalBasis{LEN, T}, x::Vector{<:NSS}) where {LEN, T <: NSS} = Bool # should the output be somethign like this? # struct Ei diff --git a/test/test_categorial.jl b/test/test_categorial.jl index d5e7e33..907b43b 100644 --- a/test/test_categorial.jl +++ b/test/test_categorial.jl @@ -17,6 +17,8 @@ for (i, c) in enumerate(elements) print_tf(@test (val2i(slist, c) == i) ) end +println() + ## simply a basis @info("Testing Categorical Basis") @@ -40,3 +42,4 @@ for c in elements l_out, st2 = l_catbasis(c, ps, st) println_slim(@test out == l_out) end + From 9db8f72165f073fb42d347c2a6b6f6160b8c7e01 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Sat, 23 Sep 2023 23:49:23 -0700 Subject: [PATCH 02/14] Minor modification on tests for Chain with CategoricalBasis --- test/test_equiv_with_cate.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_equiv_with_cate.jl b/test/test_equiv_with_cate.jl index 65ed6b4..7779600 100644 --- a/test/test_equiv_with_cate.jl +++ b/test/test_equiv_with_cate.jl @@ -11,8 +11,8 @@ Aspec, AAspec = degord2spec(; totaldegree = 4, Lmax = 0, ) cats = [:O,:C] -Aspec_tmp = [ [ (Aspec[i]..., s = cats[1]) for i = 1 : length(Aspec)]...; [ (Aspec[i]..., s = cats[2]) for i = 1 : length(Aspec)]...] -AAspec_tmp = [[Aspec_tmp[i]] for i = 1:length(Aspec_tmp)] +ext(x,cats) = [ (x[i]..., s = cats) for i = 1:length(x)] +AAspec_tmp = [ ext.(AAspec,cats[1])..., ext.(AAspec,cats[2])... ] |> sort luxchain, ps, st = equivariant_model(AAspec_tmp, L; categories=cats) F(X) = luxchain(X, ps, st)[1] From 0fb179d98fb67f573d4aec81fdf6f95fa8ebaf2a Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Sun, 24 Sep 2023 00:02:28 -0700 Subject: [PATCH 03/14] Make the above test a bit more challenging --- test/test_equiv_with_cate.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_equiv_with_cate.jl b/test/test_equiv_with_cate.jl index 7779600..8d646fc 100644 --- a/test/test_equiv_with_cate.jl +++ b/test/test_equiv_with_cate.jl @@ -13,6 +13,11 @@ cats = [:O,:C] ext(x,cats) = [ (x[i]..., s = cats) for i = 1:length(x)] AAspec_tmp = [ ext.(AAspec,cats[1])..., ext.(AAspec,cats[2])... ] |> sort +pos = findall(x -> length(x)>1, AAspec) +_AAspec_tmp = [ [(AAspec[i][1]..., s = cats[1]), (AAspec[i][2]..., s = cats[2])] for i in pos ] +_AAspec_tmp2 = [ [(AAspec[i][1]..., s = cats[2]), (AAspec[i][2]..., s = cats[1])] for i in pos ] +append!(AAspec_tmp,_AAspec_tmp) +append!(AAspec_tmp,_AAspec_tmp2) luxchain, ps, st = equivariant_model(AAspec_tmp, L; categories=cats) F(X) = luxchain(X, ps, st)[1] From fd639d01edcd773ed4a4ce3519133a99bb7f3681 Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Sun, 24 Sep 2023 14:12:41 -0700 Subject: [PATCH 04/14] Serialization from JuLIP for model with cate with z0zj --- examples/potential/forces_with_cate.jl | 172 +++++++++++++++++++++++++ 1 file changed, 172 insertions(+) create mode 100644 examples/potential/forces_with_cate.jl diff --git a/examples/potential/forces_with_cate.jl b/examples/potential/forces_with_cate.jl new file mode 100644 index 0000000..9acae55 --- /dev/null +++ b/examples/potential/forces_with_cate.jl @@ -0,0 +1,172 @@ +using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote +using Polynomials4ML: LinearLayer, RYlmBasis, lux +using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA +using JuLIP + +rng = Random.MersenneTwister() + +## +include("xx2AA.jl") + +# == configs and form model == +rcut = 5.5 +maxL = 0 +L = 4 +Aspec, AAspec = degord2spec(; totaldegree = 4, + order = 2, + Lmax = 0, ) +cats = AtomicNumber.([:W, :W]) + +new_spec = [] +ori_AAspec = deepcopy(AAspec) +new_AAspec = [] + +for bb in ori_AAspec + newbb = [] + for t in bb + push!(newbb, (t..., s = cats)) + end + push!(new_AAspec, newbb) +end + +luxchain, ps, st = equivariant_model(new_AAspec, L; categories=cats) + +#LL, ps, st, try_xnxz = myxx2AA(new_AAspec; categories = cats) +#tryps, tryst = Lux.setup(MersenneTwister(1234), try_xnxz) +#try_xnxz(X, tryps, tryst) + + +## + +# == init example data == + +at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) +nlist = JuLIP.neighbourlist(at, rcut) +_, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, 1) +# centere atom +z0 = at.Z[1] + +# serialization, I want the input data structure to lux as simple as possible +get_Z0S(zz0, ZZS) = [SVector{2}(zz0, zzs) for zzs in ZZS] +Z0S = get_Z0S(z0, Zs) + +# input of luxmodel +X = (Rs, Z0S) + +out, st = luxchain(X, ps, st) + +# === + + +# testing derivative (forces) +# g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1] + +## + +module Pot + import JuLIP, Zygote + import JuLIP: cutoff, Atoms + import ACEbase: evaluate!, evaluate_d! + + import ChainRulesCore + import ChainRulesCore: rrule, ignore_derivatives + + import Optimisers: destructure + + struct LuxCalc <: JuLIP.SitePotential + luxmodel + ps + st + rcut::Float64 + restructure + end + + function LuxCalc(luxmodel, ps, st, rcut) + pvec, rest = destructure(ps) + return LuxCalc(luxmodel, ps, st, rcut, rest) + end + + cutoff(calc::LuxCalc) = calc.rcut + + function evaluate!(tmp, calc::LuxCalc, Rs, Zs, z0) + E, st = calc.luxmodel(Rs, calc.ps, calc.st) + return E[1] + end + + function evaluate_d!(dEs, tmpd, calc::LuxCalc, Rs, Zs, z0) + g = Zygote.gradient(X -> calc.luxmodel(X, calc.ps, calc.st)[1], Rs)[1] + @assert length(g) == length(Rs) <= length(dEs) + dEs[1:length(g)] .= g + return dEs + end + + # ----- parameter estimation stuff + + + function lux_energy(at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple) + nlist = ignore_derivatives() do + JuLIP.neighbourlist(at, calc.rcut) + end + return sum( i -> begin + Js, Rs, Zs = ignore_derivatives() do + JuLIP.Potentials.neigsz(nlist, at, i) + end + Ei, st = calc.luxmodel(Rs, ps, st) + Ei[1] + end, + 1:length(at) + ) + end + + # function rrule(::typeof(lux_energy), at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple) + # E = lux_energy(at, calc, ps, st) + # function pb(Δ) + # nlist = JuLIP.neighbourlist(at, calc.rcut) + # @show Δ + # error("stop") + # function pb_inner(i) + # Js, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, i) + # gi = ReverseDiff.gradient() + # end + # for i = 1:length(at) + # Ei, st = calc.luxmodel(Rs, calc.ps, calc.st) + # E += Ei[1] + # end + # end + # end + +end + +## + +using JuLIP +JuLIP.usethreads!(false) +ps.dot.W[:] .= 0.01 * randn(length(ps.dot.W)) + +at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) +calc = Pot.LuxCalc(model, ps, st, rcut) +JuLIP.energy(calc, at) +JuLIP.forces(calc, at) +JuLIP.virial(calc, at) +Pot.lux_energy(at, calc, ps, st) + +@time JuLIP.energy(calc, at) +@time Pot.lux_energy(at, calc, ps, st) +@time JuLIP.forces(calc, at) + +## + +using Optimisers, ReverseDiff + +p_vec, _rest = destructure(ps) +f(_pvec) = Pot.lux_energy(at, calc, _rest(_pvec), st) + +f(p_vec) +g = Zygote.gradient(f, p_vec)[1] + +@time f(p_vec) +@time Zygote.gradient(f, p_vec)[1] + + +# This fails for now +# gr = ReverseDiff.gradient(f, p_vec)[1] From c7a991cab87c9bc114e2e252b77083600a667e7b Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Sun, 24 Sep 2023 14:52:49 -0700 Subject: [PATCH 05/14] minor fix --- examples/potential/forces_with_cate.jl | 10 +++------- src/builder.jl | 19 +++++++++++-------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/examples/potential/forces_with_cate.jl b/examples/potential/forces_with_cate.jl index 9acae55..ef815f6 100644 --- a/examples/potential/forces_with_cate.jl +++ b/examples/potential/forces_with_cate.jl @@ -6,7 +6,6 @@ using JuLIP rng = Random.MersenneTwister() ## -include("xx2AA.jl") # == configs and form model == rcut = 5.5 @@ -31,11 +30,6 @@ end luxchain, ps, st = equivariant_model(new_AAspec, L; categories=cats) -#LL, ps, st, try_xnxz = myxx2AA(new_AAspec; categories = cats) -#tryps, tryst = Lux.setup(MersenneTwister(1234), try_xnxz) -#try_xnxz(X, tryps, tryst) - - ## # == init example data == @@ -55,11 +49,13 @@ X = (Rs, Z0S) out, st = luxchain(X, ps, st) + + # === # testing derivative (forces) -# g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1] +g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1] ## diff --git a/src/builder.jl b/src/builder.jl index 2090761..61fa38a 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -8,6 +8,7 @@ using Lux using Random using Polynomials4ML using StaticArrays +using Combinatorics: permutations export equivariant_model, equivariant_SYY_model, equivariant_luxchain_constructor, equivariant_luxchain_constructor_new @@ -49,7 +50,7 @@ function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Ro nn = SVector([onep.n for onep in pib]...) ll = SVector([onep.l for onep in pib]...) # get a SVector of ll index if haskey(pib[1],:s) - ss = SVector([onep.s for onep in pib]...) + ss = [onep.s for onep in pib] end if haskey(pib[1],:s) @@ -136,11 +137,12 @@ function xx2AA(spec_nlm; categories=[], d=3, radial_basis = legendre_basis) # Co if !isempty(categories) # Read categories from x - TODO: discuss which format we like it to be... # For now we just give get_cat(x) a random value - get_cat(x) = length(categories) > 1 ? (iseven(floor(norm(x))) ? categories[1] : categories[2]) : categories[1] - _get_cat(x) = get_cat.(x) + #get_cat(x) = length(categories) > 1 ? (iseven(floor(norm(x))) ? categories[1] : categories[2]) : categories[1] + #_get_cat(x) = get_cat.(x) # Define categorical bases - δs = CategoricalBasis(categories) + cat_perm2 = collect(SVector{2}.(permutations(categories, 2))) + δs = CategoricalBasis(cat_perm2) l_δs = P4ML.lux(δs) end @@ -171,14 +173,15 @@ function xx2AA(spec_nlm; categories=[], d=3, radial_basis = legendre_basis) # Co l_xnx = Lux.Parallel(nothing; normx = WrappedFunction(_norm), x = WrappedFunction(identity)) l_embed = Lux.Parallel(nothing; Rn = l_Rn, Ylm = l_Ylm) else - l_xnx = Lux.Parallel(nothing; normx = WrappedFunction(_norm), x = WrappedFunction(identity), catlist = WrappedFunction(_get_cat)) + l_xnxz = Lux.BranchLayer(normx = WrappedFunction(x -> _norm(x[1])), x = WrappedFunction(x -> x[1]), catlist = WrappedFunction(x -> x[2])) + #l_xnx = Lux.Parallel(nothing; normx = WrappedFunction(_norm), x = WrappedFunction(identity), catlist = WrappedFunction(_get_cat)) l_embed = Lux.Parallel(nothing; Rn = l_Rn, Ylm = l_Ylm, δs = l_δs) end - luxchain = Chain(xnx = l_xnx, embed = l_embed, A = l_bA , AA = l_bAA, AA_sort = WrappedFunction(x -> x[pos])) + luxchain = Chain(l_xnxz = l_xnxz, embed = l_embed, A = l_bA , AA = l_bAA, AA_sort = WrappedFunction(x -> x[pos])) ps, st = Lux.setup(MersenneTwister(1234), luxchain) - return luxchain, ps, st + return luxchain, ps, st, l_xnxz end """ @@ -197,7 +200,7 @@ function equivariant_model(spec_nlm, L::Int64; categories=[], d=3, radial_basis= # sort!(spec_nlm, by = x -> length(x)) spec_nlm = closure(spec_nlm,filter_init; categories = categories) - luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm; categories = categories, d = d, radial_basis = radial_basis) + luxchain_tmp, ps_tmp, st_tmp, _ = EquivariantModels.xx2AA(spec_nlm; categories = categories, d = d, radial_basis = radial_basis) F(X) = luxchain_tmp(X, ps_tmp, st_tmp)[1] if islong From d5d61bfb2a419c833421587bd77631c5e8713210 Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Sun, 24 Sep 2023 15:14:50 -0700 Subject: [PATCH 06/14] minor fix --- src/builder.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/builder.jl b/src/builder.jl index 61fa38a..c9b492f 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -172,16 +172,17 @@ function xx2AA(spec_nlm; categories=[], d=3, radial_basis = legendre_basis) # Co if isempty(categories) l_xnx = Lux.Parallel(nothing; normx = WrappedFunction(_norm), x = WrappedFunction(identity)) l_embed = Lux.Parallel(nothing; Rn = l_Rn, Ylm = l_Ylm) + luxchain = Chain(l_xnx = l_xnx, embed = l_embed, A = l_bA , AA = l_bAA, AA_sort = WrappedFunction(x -> x[pos])) else l_xnxz = Lux.BranchLayer(normx = WrappedFunction(x -> _norm(x[1])), x = WrappedFunction(x -> x[1]), catlist = WrappedFunction(x -> x[2])) - #l_xnx = Lux.Parallel(nothing; normx = WrappedFunction(_norm), x = WrappedFunction(identity), catlist = WrappedFunction(_get_cat)) l_embed = Lux.Parallel(nothing; Rn = l_Rn, Ylm = l_Ylm, δs = l_δs) + luxchain = Chain(l_xnxz = l_xnxz, embed = l_embed, A = l_bA , AA = l_bAA, AA_sort = WrappedFunction(x -> x[pos])) end - luxchain = Chain(l_xnxz = l_xnxz, embed = l_embed, A = l_bA , AA = l_bAA, AA_sort = WrappedFunction(x -> x[pos])) + # luxchain = Chain(l_xnxz = l_xnxz, embed = l_embed, A = l_bA , AA = l_bAA, AA_sort = WrappedFunction(x -> x[pos])) ps, st = Lux.setup(MersenneTwister(1234), luxchain) - return luxchain, ps, st, l_xnxz + return luxchain, ps, st end """ @@ -200,7 +201,7 @@ function equivariant_model(spec_nlm, L::Int64; categories=[], d=3, radial_basis= # sort!(spec_nlm, by = x -> length(x)) spec_nlm = closure(spec_nlm,filter_init; categories = categories) - luxchain_tmp, ps_tmp, st_tmp, _ = EquivariantModels.xx2AA(spec_nlm; categories = categories, d = d, radial_basis = radial_basis) + luxchain_tmp, ps_tmp, st_tmp = EquivariantModels.xx2AA(spec_nlm; categories = categories, d = d, radial_basis = radial_basis) F(X) = luxchain_tmp(X, ps_tmp, st_tmp)[1] if islong From b397cc63dbf0c8d6286c31e0147e8caaa0a22012 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Sun, 24 Sep 2023 15:45:34 -0700 Subject: [PATCH 07/14] Minor changes on categories --- src/builder.jl | 4 +--- test/runtests.jl | 3 ++- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/builder.jl b/src/builder.jl index c9b492f..7fc7fde 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -8,7 +8,6 @@ using Lux using Random using Polynomials4ML using StaticArrays -using Combinatorics: permutations export equivariant_model, equivariant_SYY_model, equivariant_luxchain_constructor, equivariant_luxchain_constructor_new @@ -141,8 +140,7 @@ function xx2AA(spec_nlm; categories=[], d=3, radial_basis = legendre_basis) # Co #_get_cat(x) = get_cat.(x) # Define categorical bases - cat_perm2 = collect(SVector{2}.(permutations(categories, 2))) - δs = CategoricalBasis(cat_perm2) + δs = CategoricalBasis(categories) l_δs = P4ML.lux(δs) end diff --git a/test/runtests.jl b/test/runtests.jl index 729f071..8f79463 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,5 +3,6 @@ using Test @testset "EquivariantModels.jl" begin @testset "CategoricalBasis" begin include("test_categorial.jl") end - @testset "Equivariance" begin include("test_equivariance.jl"); include("test_equiv_with_cate.jl"); end + # @testset "Equivariance" begin include("test_equivariance.jl"); include("test_equiv_with_cate.jl"); end + @testset "Equivariance" begin include("test_equivariance.jl") end end From 75969d1ab0b3bdd318b31d3e8e3b733ee50a6766 Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Sun, 24 Sep 2023 16:37:44 -0700 Subject: [PATCH 08/14] add gradient working example, quick workaround for some problem in P4ML to be discussed --- examples/potential/forces_with_cate.jl | 18 ++++++++++++++---- examples/potential/staticprod.jl | 25 +++++++++++++++++++++++++ src/categorical.jl | 12 +++++++++++- 3 files changed, 50 insertions(+), 5 deletions(-) create mode 100644 examples/potential/staticprod.jl diff --git a/examples/potential/forces_with_cate.jl b/examples/potential/forces_with_cate.jl index ef815f6..d9c0fab 100644 --- a/examples/potential/forces_with_cate.jl +++ b/examples/potential/forces_with_cate.jl @@ -2,9 +2,11 @@ using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote using Polynomials4ML: LinearLayer, RYlmBasis, lux using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA using JuLIP +using Combinatorics: permutations rng = Random.MersenneTwister() +include("staticprod.jl") ## # == configs and form model == @@ -28,7 +30,9 @@ for bb in ori_AAspec push!(new_AAspec, newbb) end -luxchain, ps, st = equivariant_model(new_AAspec, L; categories=cats) +cat_perm2 = collect(SVector{2}.(permutations(cats, 2))) + +luxchain, ps, st = equivariant_model(new_AAspec, 0; categories = cat_perm2, islong = false) ## @@ -47,15 +51,21 @@ Z0S = get_Z0S(z0, Zs) # input of luxmodel X = (Rs, Z0S) -out, st = luxchain(X, ps, st) +# == lux chain eval and grad +out, st = luxchain(X, ps, st) +B = out +model = append_layers(luxchain, get1 = WrappedFunction(t -> real.(t)), dot = LinearLayer(length(B), 1), get2 = WrappedFunction(t -> t[1])) -# === +ps, st = Lux.setup(MersenneTwister(1234), model) +model(X, ps, st) # testing derivative (forces) -g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1] +g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1] + + ## diff --git a/examples/potential/staticprod.jl b/examples/potential/staticprod.jl new file mode 100644 index 0000000..578fe7c --- /dev/null +++ b/examples/potential/staticprod.jl @@ -0,0 +1,25 @@ +import Polynomials4ML: _static_prod_ed, _pb_grad_static_prod + +function _static_prod_ed(b::NTuple{N, Any}) where N + b2 = b[2:N] + p2, g2 = _static_prod_ed(b2) + return b[1] * p2, tuple(p2, ntuple(i -> b[1] * g2[i], N-1)...) +end + +function _static_prod_ed(b::NTuple{1, Any}) + return b[1], (one(T),) +end + +function _pb_grad_static_prod(∂::NTuple{N, Any}, b::NTuple{N, Any}) where N + ∂2 = ∂[2:N] + b2 = b[2:N] + p2, g2, u2 = _pb_grad_static_prod(∂2, b2) + return b[1] * p2, + tuple(p2, ntuple(i -> b[1] * g2[i], N-1)...), + tuple(sum(∂2 .* g2), ntuple(i -> ∂[1] * g2[i] + b[1] * u2[i], N-1)...) + end + +function _pb_grad_static_prod(∂::NTuple{1, Any}, b::NTuple{1, Any}) + return b[1], (one(T),), (zero(T),) +end + \ No newline at end of file diff --git a/src/categorical.jl b/src/categorical.jl index d7b5c82..065ad2b 100644 --- a/src/categorical.jl +++ b/src/categorical.jl @@ -2,7 +2,8 @@ using Polynomials4ML using Polynomials4ML: AbstractPoly4MLBasis, SphericalCoords using StaticArrays: SVector, StaticArray -import Polynomials4ML: _valtype, _out_size, _outsym +import Polynomials4ML: _valtype, _out_size, _outsym, evaluate, evaluate! +import ChainRulesCore: rrule, NoTangent export CategoricalBasis @@ -146,6 +147,15 @@ Polynomials4ML.natural_indices(basis::CategoricalBasis) = basis.categories.list Base.rand(basis::CategoricalBasis) = rand(basis.categories) +## rrule +function rrule(::typeof(evaluate), basis::CategoricalBasis, x) + A = evaluate(basis, x) + function pb(x) + return NoTangent(), NoTangent(), NoTangent() + end + return A, pb +end + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # probably we don't need the rest, but keep around for now From f8f7fd7a4f16ab607c44786bc2a702f3ffb8bd0e Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Sun, 24 Sep 2023 16:53:29 -0700 Subject: [PATCH 09/14] Re-enable tests for equivariant_cat_included_chain --- examples/potential/forces_with_cate_loc.jl | 170 +++++++++++++++++++++ test/runtests.jl | 3 +- test/test_equiv_with_cate.jl | 24 ++- 3 files changed, 187 insertions(+), 10 deletions(-) create mode 100644 examples/potential/forces_with_cate_loc.jl diff --git a/examples/potential/forces_with_cate_loc.jl b/examples/potential/forces_with_cate_loc.jl new file mode 100644 index 0000000..cb22bda --- /dev/null +++ b/examples/potential/forces_with_cate_loc.jl @@ -0,0 +1,170 @@ +using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote +using Polynomials4ML: LinearLayer, RYlmBasis, lux +using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA +using JuLIP +using Combinatorics: permutations + +rng = Random.MersenneTwister() + +## + +# == configs and form model == +rcut = 5.5 +maxL = 0 +L = 4 +Aspec, AAspec = degord2spec(; totaldegree = 4, + order = 2, + Lmax = 0, ) +cats = AtomicNumber.([:W, :W]) +cat = collect(SVector{2}.(permutations(categories, 2))) + +new_spec = [] +ori_AAspec = deepcopy(AAspec) +new_AAspec = [] + +for bb in ori_AAspec + newbb = [] + for t in bb + push!(newbb, (t..., s = cats)) + end + push!(new_AAspec, newbb) +end + +luxchain, ps, st = equivariant_model(new_AAspec, L; categories=cats) + +## + +# == init example data == + +at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) +nlist = JuLIP.neighbourlist(at, rcut) +_, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, 1) +# centere atom +z0 = at.Z[1] + +# serialization, I want the input data structure to lux as simple as possible +get_Z0S(zz0, ZZS) = [SVector{2}(zz0, zzs) for zzs in ZZS] +Z0S = get_Z0S(z0, Zs) + +# input of luxmodel +X = (Rs, Z0S) + +out, st = luxchain(X, ps, st) + + + +# === + + +# testing derivative (forces) +g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1] + +## + +module Pot + import JuLIP, Zygote + import JuLIP: cutoff, Atoms + import ACEbase: evaluate!, evaluate_d! + + import ChainRulesCore + import ChainRulesCore: rrule, ignore_derivatives + + import Optimisers: destructure + + struct LuxCalc <: JuLIP.SitePotential + luxmodel + ps + st + rcut::Float64 + restructure + end + + function LuxCalc(luxmodel, ps, st, rcut) + pvec, rest = destructure(ps) + return LuxCalc(luxmodel, ps, st, rcut, rest) + end + + cutoff(calc::LuxCalc) = calc.rcut + + function evaluate!(tmp, calc::LuxCalc, Rs, Zs, z0) + E, st = calc.luxmodel(Rs, calc.ps, calc.st) + return E[1] + end + + function evaluate_d!(dEs, tmpd, calc::LuxCalc, Rs, Zs, z0) + g = Zygote.gradient(X -> calc.luxmodel(X, calc.ps, calc.st)[1], Rs)[1] + @assert length(g) == length(Rs) <= length(dEs) + dEs[1:length(g)] .= g + return dEs + end + + # ----- parameter estimation stuff + + + function lux_energy(at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple) + nlist = ignore_derivatives() do + JuLIP.neighbourlist(at, calc.rcut) + end + return sum( i -> begin + Js, Rs, Zs = ignore_derivatives() do + JuLIP.Potentials.neigsz(nlist, at, i) + end + Ei, st = calc.luxmodel(Rs, ps, st) + Ei[1] + end, + 1:length(at) + ) + end + + # function rrule(::typeof(lux_energy), at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple) + # E = lux_energy(at, calc, ps, st) + # function pb(Δ) + # nlist = JuLIP.neighbourlist(at, calc.rcut) + # @show Δ + # error("stop") + # function pb_inner(i) + # Js, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, i) + # gi = ReverseDiff.gradient() + # end + # for i = 1:length(at) + # Ei, st = calc.luxmodel(Rs, calc.ps, calc.st) + # E += Ei[1] + # end + # end + # end + +end + +## + +using JuLIP +JuLIP.usethreads!(false) +ps.dot.W[:] .= 0.01 * randn(length(ps.dot.W)) + +at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) +calc = Pot.LuxCalc(model, ps, st, rcut) +JuLIP.energy(calc, at) +JuLIP.forces(calc, at) +JuLIP.virial(calc, at) +Pot.lux_energy(at, calc, ps, st) + +@time JuLIP.energy(calc, at) +@time Pot.lux_energy(at, calc, ps, st) +@time JuLIP.forces(calc, at) + +## + +using Optimisers, ReverseDiff + +p_vec, _rest = destructure(ps) +f(_pvec) = Pot.lux_energy(at, calc, _rest(_pvec), st) + +f(p_vec) +g = Zygote.gradient(f, p_vec)[1] + +@time f(p_vec) +@time Zygote.gradient(f, p_vec)[1] + + +# This fails for now +# gr = ReverseDiff.gradient(f, p_vec)[1] diff --git a/test/runtests.jl b/test/runtests.jl index 8f79463..729f071 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,5 @@ using Test @testset "EquivariantModels.jl" begin @testset "CategoricalBasis" begin include("test_categorial.jl") end - # @testset "Equivariance" begin include("test_equivariance.jl"); include("test_equiv_with_cate.jl"); end - @testset "Equivariance" begin include("test_equivariance.jl") end + @testset "Equivariance" begin include("test_equivariance.jl"); include("test_equiv_with_cate.jl"); end end diff --git a/test/test_equiv_with_cate.jl b/test/test_equiv_with_cate.jl index 8d646fc..5d90182 100644 --- a/test/test_equiv_with_cate.jl +++ b/test/test_equiv_with_cate.jl @@ -1,6 +1,7 @@ using Polynomials4ML, StaticArrays, EquivariantModels, Test, Rotations, LinearAlgebra using ACEbase.Testing: print_tf using EquivariantModels: getspec1idx, _invmap, dropnames, SList, val2i, xx2AA, degord2spec +using Combinatorics: permutations include("wigner.jl") @@ -10,34 +11,41 @@ Aspec, AAspec = degord2spec(; totaldegree = 4, order = 2, Lmax = 0, ) cats = [:O,:C] - -ext(x,cats) = [ (x[i]..., s = cats) for i = 1:length(x)] -AAspec_tmp = [ ext.(AAspec,cats[1])..., ext.(AAspec,cats[2])... ] |> sort +cats_ext = [(:O,:C),(:C,:O),(:O,:O),(:C,:C)] |> unique +AAspec_tmp = [] +for i = 1:length(AAspec) + push!(AAspec_tmp, [ (spec..., s = cats_ext[1]) for spec in AAspec[i] ]) + push!(AAspec_tmp, [ (spec..., s = cats_ext[2]) for spec in AAspec[i] ]) +end pos = findall(x -> length(x)>1, AAspec) -_AAspec_tmp = [ [(AAspec[i][1]..., s = cats[1]), (AAspec[i][2]..., s = cats[2])] for i in pos ] -_AAspec_tmp2 = [ [(AAspec[i][1]..., s = cats[2]), (AAspec[i][2]..., s = cats[1])] for i in pos ] +_AAspec_tmp = [ [(AAspec[i][1]..., s = cats_ext[1]), (AAspec[i][2]..., s = cats_ext[2])] for i in pos ] +_AAspec_tmp2 = [ [(AAspec[i][1]..., s = cats_ext[2]), (AAspec[i][2]..., s = cats_ext[1])] for i in pos ] append!(AAspec_tmp,_AAspec_tmp) append!(AAspec_tmp,_AAspec_tmp2) -luxchain, ps, st = equivariant_model(AAspec_tmp, L; categories=cats) +luxchain, ps, st = equivariant_model(AAspec_tmp, L; categories=cats_ext) F(X) = luxchain(X, ps, st)[1] +species = [ rand(cats) for i = 1:10 ] +Species = [ (species[1], species[i]) for i = 1:10 ] @info("Testing the equivariance of chains that contain categorical basis") for ntest = 1:10 local X, θ1, θ2, θ3, Q, QX X = [ @SVector(rand(3)) for i in 1:10 ] + XX = (X, Species) θ1 = rand() * 2pi θ2 = rand() * 2pi θ3 = rand() * 2pi Q = RotXYZ(θ1, θ2, θ3) QX = [SVector{3}(x) for x in Ref(Q) .* X] + QXX = (QX, Species) - print_tf(@test F(X)[1] ≈ F(QX)[1]) + print_tf(@test F(XX)[1] ≈ F(QXX)[1]) for l = 2:L D = wigner_D(l-1,Matrix(Q))' # D = wignerD(l-1, 0, 0, θ) - print_tf(@test norm.(Ref(D') .* F(X)[l] - F(QX)[l]) |> norm <1e-12) + print_tf(@test Ref(D') .* F(XX)[l] ≈ F(QXX)[l]) end end From 407a02bcf8170b8d46c76ee6f1600fca65193461 Mon Sep 17 00:00:00 2001 From: YangshuaiWang Date: Sun, 24 Sep 2023 17:04:25 -0700 Subject: [PATCH 10/14] initial test of energy with multi species --- .../potential/energy_with_multi_species.jl | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 examples/potential/energy_with_multi_species.jl diff --git a/examples/potential/energy_with_multi_species.jl b/examples/potential/energy_with_multi_species.jl new file mode 100644 index 0000000..1b3d8c5 --- /dev/null +++ b/examples/potential/energy_with_multi_species.jl @@ -0,0 +1,52 @@ +using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote +using Polynomials4ML: LinearLayer, RYlmBasis, lux +using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA +using JuLIP, Combinatorics + +rng = Random.MersenneTwister() + +rcut = 5.5 +maxL = 0 +L = 0 +Aspec, AAspec = degord2spec(; totaldegree = 6, + order = 3, + Lmax = 0, ) +cats = AtomicNumber.([:W, :Cu, :Ni, :Fe, :Al]) +ipairs = collect(Combinatorics.permutations(1:length(cats), 2)) +allcats = collect(SVector{2}.(Combinatorics.permutations(cats, 2))) + +for (i, cat) in enumerate(cats) + push!(ipairs, [i, i]) + push!(allcats, SVector{2}([cat, cat])) +end + +new_spec = [] +ori_AAspec = deepcopy(AAspec) +new_AAspec = [] + +for bb in ori_AAspec + newbb = [] + for (t, ip) in zip(bb, ipairs) + push!(newbb, (t..., s = cats[ip])) + end + push!(new_AAspec, newbb) +end + +luxchain, ps, st = equivariant_model(new_AAspec, L; categories=allcats) + +at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) +iCu = [5, 12]; iNi = [3, 8]; iAl = [10]; iFe = [6]; +at.Z[iCu] .= cats[2]; at.Z[iNi] .= cats[3]; at.Z[iAl] .= cats[4]; at.Z[iFe] .= cats[5]; +nlist = JuLIP.neighbourlist(at, rcut) +_, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, 1) +# centere atom +z0 = at.Z[1] + +# serialization, I want the input data structure to lux as simple as possible +get_Z0S(zz0, ZZS) = [SVector{2}(zz0, zzs) for zzs in ZZS] +Z0S = get_Z0S(z0, Zs) + +# input of luxmodel +X = (Rs, Z0S) + +out, st = luxchain(X, ps, st) From 7d2ba046d0672f382403c20331d2f01a4b06cee9 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Sun, 24 Sep 2023 17:04:42 -0700 Subject: [PATCH 11/14] Typo fix --- test/test_equiv_with_cate.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_equiv_with_cate.jl b/test/test_equiv_with_cate.jl index 5d90182..fb49796 100644 --- a/test/test_equiv_with_cate.jl +++ b/test/test_equiv_with_cate.jl @@ -1,7 +1,6 @@ using Polynomials4ML, StaticArrays, EquivariantModels, Test, Rotations, LinearAlgebra using ACEbase.Testing: print_tf using EquivariantModels: getspec1idx, _invmap, dropnames, SList, val2i, xx2AA, degord2spec -using Combinatorics: permutations include("wigner.jl") From 995e2b76914a5dbd64a5d097eeb7bcc95491d920 Mon Sep 17 00:00:00 2001 From: cheukhinhojerry Date: Sun, 24 Sep 2023 17:09:32 -0700 Subject: [PATCH 12/14] multi cate JuLIP.SitePot working --- examples/potential/forces_with_cate.jl | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/examples/potential/forces_with_cate.jl b/examples/potential/forces_with_cate.jl index d9c0fab..f686f0e 100644 --- a/examples/potential/forces_with_cate.jl +++ b/examples/potential/forces_with_cate.jl @@ -63,13 +63,15 @@ ps, st = Lux.setup(MersenneTwister(1234), model) model(X, ps, st) # testing derivative (forces) -g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1] +g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1][1] ## -module Pot +module Pot + using StaticArrays: SVector + import JuLIP, Zygote import JuLIP: cutoff, Atoms import ACEbase: evaluate!, evaluate_d! @@ -79,6 +81,8 @@ module Pot import Optimisers: destructure + get_Z0S(zz0, ZZS) = [SVector{2}(zz0, zzs) for zzs in ZZS] + struct LuxCalc <: JuLIP.SitePotential luxmodel ps @@ -95,12 +99,16 @@ module Pot cutoff(calc::LuxCalc) = calc.rcut function evaluate!(tmp, calc::LuxCalc, Rs, Zs, z0) - E, st = calc.luxmodel(Rs, calc.ps, calc.st) + Z0S = get_Z0S(z0, Zs) + X = (Rs, Z0S) + E, st = calc.luxmodel(X, calc.ps, calc.st) return E[1] end function evaluate_d!(dEs, tmpd, calc::LuxCalc, Rs, Zs, z0) - g = Zygote.gradient(X -> calc.luxmodel(X, calc.ps, calc.st)[1], Rs)[1] + Z0S = get_Z0S(z0, Zs) + X = (Rs, Z0S) + g = Zygote.gradient(X -> calc.luxmodel(X, calc.ps, calc.st)[1], X)[1][1] @assert length(g) == length(Rs) <= length(dEs) dEs[1:length(g)] .= g return dEs @@ -117,7 +125,9 @@ module Pot Js, Rs, Zs = ignore_derivatives() do JuLIP.Potentials.neigsz(nlist, at, i) end - Ei, st = calc.luxmodel(Rs, ps, st) + Z0S = get_Z0S(at.Z[1], Zs) + X = (Rs, Z0S) + Ei, st = calc.luxmodel(X, ps, st) Ei[1] end, 1:length(at) From d370ebcfb0d6baa8fce70d40f9858156dee648fe Mon Sep 17 00:00:00 2001 From: Liwei Zhang <68692847+zhanglw0521@users.noreply.github.com> Date: Sun, 24 Sep 2023 17:14:20 -0700 Subject: [PATCH 13/14] Delete examples/potential/forces_with_cate_loc.jl Redundant file - was pushed mistakenly --- examples/potential/forces_with_cate_loc.jl | 170 --------------------- 1 file changed, 170 deletions(-) delete mode 100644 examples/potential/forces_with_cate_loc.jl diff --git a/examples/potential/forces_with_cate_loc.jl b/examples/potential/forces_with_cate_loc.jl deleted file mode 100644 index cb22bda..0000000 --- a/examples/potential/forces_with_cate_loc.jl +++ /dev/null @@ -1,170 +0,0 @@ -using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote -using Polynomials4ML: LinearLayer, RYlmBasis, lux -using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA -using JuLIP -using Combinatorics: permutations - -rng = Random.MersenneTwister() - -## - -# == configs and form model == -rcut = 5.5 -maxL = 0 -L = 4 -Aspec, AAspec = degord2spec(; totaldegree = 4, - order = 2, - Lmax = 0, ) -cats = AtomicNumber.([:W, :W]) -cat = collect(SVector{2}.(permutations(categories, 2))) - -new_spec = [] -ori_AAspec = deepcopy(AAspec) -new_AAspec = [] - -for bb in ori_AAspec - newbb = [] - for t in bb - push!(newbb, (t..., s = cats)) - end - push!(new_AAspec, newbb) -end - -luxchain, ps, st = equivariant_model(new_AAspec, L; categories=cats) - -## - -# == init example data == - -at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) -nlist = JuLIP.neighbourlist(at, rcut) -_, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, 1) -# centere atom -z0 = at.Z[1] - -# serialization, I want the input data structure to lux as simple as possible -get_Z0S(zz0, ZZS) = [SVector{2}(zz0, zzs) for zzs in ZZS] -Z0S = get_Z0S(z0, Zs) - -# input of luxmodel -X = (Rs, Z0S) - -out, st = luxchain(X, ps, st) - - - -# === - - -# testing derivative (forces) -g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1] - -## - -module Pot - import JuLIP, Zygote - import JuLIP: cutoff, Atoms - import ACEbase: evaluate!, evaluate_d! - - import ChainRulesCore - import ChainRulesCore: rrule, ignore_derivatives - - import Optimisers: destructure - - struct LuxCalc <: JuLIP.SitePotential - luxmodel - ps - st - rcut::Float64 - restructure - end - - function LuxCalc(luxmodel, ps, st, rcut) - pvec, rest = destructure(ps) - return LuxCalc(luxmodel, ps, st, rcut, rest) - end - - cutoff(calc::LuxCalc) = calc.rcut - - function evaluate!(tmp, calc::LuxCalc, Rs, Zs, z0) - E, st = calc.luxmodel(Rs, calc.ps, calc.st) - return E[1] - end - - function evaluate_d!(dEs, tmpd, calc::LuxCalc, Rs, Zs, z0) - g = Zygote.gradient(X -> calc.luxmodel(X, calc.ps, calc.st)[1], Rs)[1] - @assert length(g) == length(Rs) <= length(dEs) - dEs[1:length(g)] .= g - return dEs - end - - # ----- parameter estimation stuff - - - function lux_energy(at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple) - nlist = ignore_derivatives() do - JuLIP.neighbourlist(at, calc.rcut) - end - return sum( i -> begin - Js, Rs, Zs = ignore_derivatives() do - JuLIP.Potentials.neigsz(nlist, at, i) - end - Ei, st = calc.luxmodel(Rs, ps, st) - Ei[1] - end, - 1:length(at) - ) - end - - # function rrule(::typeof(lux_energy), at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple) - # E = lux_energy(at, calc, ps, st) - # function pb(Δ) - # nlist = JuLIP.neighbourlist(at, calc.rcut) - # @show Δ - # error("stop") - # function pb_inner(i) - # Js, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, i) - # gi = ReverseDiff.gradient() - # end - # for i = 1:length(at) - # Ei, st = calc.luxmodel(Rs, calc.ps, calc.st) - # E += Ei[1] - # end - # end - # end - -end - -## - -using JuLIP -JuLIP.usethreads!(false) -ps.dot.W[:] .= 0.01 * randn(length(ps.dot.W)) - -at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) -calc = Pot.LuxCalc(model, ps, st, rcut) -JuLIP.energy(calc, at) -JuLIP.forces(calc, at) -JuLIP.virial(calc, at) -Pot.lux_energy(at, calc, ps, st) - -@time JuLIP.energy(calc, at) -@time Pot.lux_energy(at, calc, ps, st) -@time JuLIP.forces(calc, at) - -## - -using Optimisers, ReverseDiff - -p_vec, _rest = destructure(ps) -f(_pvec) = Pot.lux_energy(at, calc, _rest(_pvec), st) - -f(p_vec) -g = Zygote.gradient(f, p_vec)[1] - -@time f(p_vec) -@time Zygote.gradient(f, p_vec)[1] - - -# This fails for now -# gr = ReverseDiff.gradient(f, p_vec)[1] From 681085143e8935c3c59b7a2eb2fa7d036ae85094 Mon Sep 17 00:00:00 2001 From: jerryho Date: Mon, 25 Sep 2023 00:18:10 -0400 Subject: [PATCH 14/14] add test grad and gradp --- examples/potential/test_potential.jl | 105 +++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 examples/potential/test_potential.jl diff --git a/examples/potential/test_potential.jl b/examples/potential/test_potential.jl new file mode 100644 index 0000000..53926af --- /dev/null +++ b/examples/potential/test_potential.jl @@ -0,0 +1,105 @@ +using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote +using Polynomials4ML: LinearLayer, RYlmBasis, lux +using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA +using JuLIP, Combinatorics, Test +using ACEbase.Testing: println_slim, print_tf, fdtest +using Optimisers: destructure +using Printf + +include("staticprod.jl") + +function grad_test2(f, df, X::AbstractVector; verbose = true) + F = f(X) + ∇F = df(X) + nX = length(X) + EE = Matrix(I, (nX, nX)) + + verbose && @printf("---------|----------- \n") + verbose && @printf(" h | error \n") + verbose && @printf("---------|----------- \n") + for h in 0.1.^(-3:9) + gh = [ (f(X + h * EE[:, i]) - F) / h for i = 1:nX ] + verbose && @printf(" %.1e | %.2e \n", h, norm(gh - ∇F, Inf)) + end +end + +rng = Random.MersenneTwister() + +rcut = 5.5 +maxL = 0 +L = 0 +Aspec, AAspec = degord2spec(; totaldegree = 6, + order = 3, + Lmax = 0, ) +cats = AtomicNumber.([:W, :Cu, :Ni, :Fe, :Al]) +ipairs = collect(Combinatorics.permutations(1:length(cats), 2)) +allcats = collect(SVector{2}.(Combinatorics.permutations(cats, 2))) + +for (i, cat) in enumerate(cats) + push!(ipairs, [i, i]) + push!(allcats, SVector{2}([cat, cat])) +end + +new_spec = [] +ori_AAspec = deepcopy(AAspec) +new_AAspec = [] + +for bb in ori_AAspec + newbb = [] + for (t, ip) in zip(bb, ipairs) + push!(newbb, (t..., s = cats[ip])) + end + push!(new_AAspec, newbb) +end + +luxchain, ps, st = equivariant_model(new_AAspec, L; categories=allcats, islong = false) + +at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1) +iCu = [5, 12]; iNi = [3, 8]; iAl = [10]; iFe = [6]; +at.Z[iCu] .= cats[2]; at.Z[iNi] .= cats[3]; at.Z[iAl] .= cats[4]; at.Z[iFe] .= cats[5]; +nlist = JuLIP.neighbourlist(at, rcut) +_, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, 1) +# centere atom +z0 = at.Z[1] + +# serialization, I want the input data structure to lux as simple as possible +get_Z0S(zz0, ZZS) = [SVector{2}(zz0, zzs) for zzs in ZZS] +Z0S = get_Z0S(z0, Zs) + +# input of luxmodel +X = (Rs, Z0S) + +out, st = luxchain(X, ps, st) + + +# == lux chain eval and grad +B = out + +model = append_layers(luxchain, get1 = WrappedFunction(t -> real.(t)), dot = LinearLayer(length(B), 1), get2 = WrappedFunction(t -> t[1])) + +ps, st = Lux.setup(MersenneTwister(1234), model) + +model(X, ps, st) + +# testing derivative (forces) +g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1] + + +F(Rs) = model((Rs, Z0S), ps, st)[1] +dF(Rs) = Zygote.gradient(rs -> model((rs, Z0S), ps, st)[1], Rs)[1] + +## +@info("test derivative w.r.t X") +print_tf(@test fdtest(F, dF, Rs; verbose=true)) + + +@info("test derivative w.r.t parameter") +p = Zygote.gradient(p -> model(X, p, st)[1], ps)[1] +p, = destructure(p) + +W0, re = destructure(ps) +Fp = w -> model(X, re(w), st)[1] +dFp = w -> ( gl = Zygote.gradient(p -> model(X, p, st)[1], ps)[1]; destructure(gl)[1]) +grad_test2(Fp, dFp, W0) + +