Skip to content

Commit

Permalink
Merge branch 'potential' of github.com:ACEsuit/EquivariantModels.jl i…
Browse files Browse the repository at this point in the history
…nto potential
  • Loading branch information
cortner committed Sep 25, 2023
2 parents ee91b3c + 6810851 commit 9498a15
Show file tree
Hide file tree
Showing 8 changed files with 423 additions and 17 deletions.
52 changes: 52 additions & 0 deletions examples/potential/energy_with_multi_species.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote
using Polynomials4ML: LinearLayer, RYlmBasis, lux
using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA
using JuLIP, Combinatorics

rng = Random.MersenneTwister()

rcut = 5.5
maxL = 0
L = 0
Aspec, AAspec = degord2spec(; totaldegree = 6,
order = 3,
Lmax = 0, )
cats = AtomicNumber.([:W, :Cu, :Ni, :Fe, :Al])
ipairs = collect(Combinatorics.permutations(1:length(cats), 2))
allcats = collect(SVector{2}.(Combinatorics.permutations(cats, 2)))

for (i, cat) in enumerate(cats)
push!(ipairs, [i, i])
push!(allcats, SVector{2}([cat, cat]))
end

new_spec = []
ori_AAspec = deepcopy(AAspec)
new_AAspec = []

for bb in ori_AAspec
newbb = []
for (t, ip) in zip(bb, ipairs)
push!(newbb, (t..., s = cats[ip]))
end
push!(new_AAspec, newbb)
end

luxchain, ps, st = equivariant_model(new_AAspec, L; categories=allcats)

at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1)
iCu = [5, 12]; iNi = [3, 8]; iAl = [10]; iFe = [6];
at.Z[iCu] .= cats[2]; at.Z[iNi] .= cats[3]; at.Z[iAl] .= cats[4]; at.Z[iFe] .= cats[5];
nlist = JuLIP.neighbourlist(at, rcut)
_, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, 1)
# centere atom
z0 = at.Z[1]

# serialization, I want the input data structure to lux as simple as possible
get_Z0S(zz0, ZZS) = [SVector{2}(zz0, zzs) for zzs in ZZS]
Z0S = get_Z0S(z0, Zs)

# input of luxmodel
X = (Rs, Z0S)

out, st = luxchain(X, ps, st)
188 changes: 188 additions & 0 deletions examples/potential/forces_with_cate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote
using Polynomials4ML: LinearLayer, RYlmBasis, lux
using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA
using JuLIP
using Combinatorics: permutations

rng = Random.MersenneTwister()

include("staticprod.jl")
##

# == configs and form model ==
rcut = 5.5
maxL = 0
L = 4
Aspec, AAspec = degord2spec(; totaldegree = 4,
order = 2,
Lmax = 0, )
cats = AtomicNumber.([:W, :W])

new_spec = []
ori_AAspec = deepcopy(AAspec)
new_AAspec = []

for bb in ori_AAspec
newbb = []
for t in bb
push!(newbb, (t..., s = cats))
end
push!(new_AAspec, newbb)
end

cat_perm2 = collect(SVector{2}.(permutations(cats, 2)))

luxchain, ps, st = equivariant_model(new_AAspec, 0; categories = cat_perm2, islong = false)

##

# == init example data ==

at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1)
nlist = JuLIP.neighbourlist(at, rcut)
_, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, 1)
# centere atom
z0 = at.Z[1]

# serialization, I want the input data structure to lux as simple as possible
get_Z0S(zz0, ZZS) = [SVector{2}(zz0, zzs) for zzs in ZZS]
Z0S = get_Z0S(z0, Zs)

# input of luxmodel
X = (Rs, Z0S)


# == lux chain eval and grad
out, st = luxchain(X, ps, st)
B = out

model = append_layers(luxchain, get1 = WrappedFunction(t -> real.(t)), dot = LinearLayer(length(B), 1), get2 = WrappedFunction(t -> t[1]))

ps, st = Lux.setup(MersenneTwister(1234), model)

model(X, ps, st)

# testing derivative (forces)
g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1][1]



##

module Pot
using StaticArrays: SVector

import JuLIP, Zygote
import JuLIP: cutoff, Atoms
import ACEbase: evaluate!, evaluate_d!

import ChainRulesCore
import ChainRulesCore: rrule, ignore_derivatives

import Optimisers: destructure

get_Z0S(zz0, ZZS) = [SVector{2}(zz0, zzs) for zzs in ZZS]

struct LuxCalc <: JuLIP.SitePotential
luxmodel
ps
st
rcut::Float64
restructure
end

function LuxCalc(luxmodel, ps, st, rcut)
pvec, rest = destructure(ps)
return LuxCalc(luxmodel, ps, st, rcut, rest)
end

cutoff(calc::LuxCalc) = calc.rcut

function evaluate!(tmp, calc::LuxCalc, Rs, Zs, z0)
Z0S = get_Z0S(z0, Zs)
X = (Rs, Z0S)
E, st = calc.luxmodel(X, calc.ps, calc.st)
return E[1]
end

function evaluate_d!(dEs, tmpd, calc::LuxCalc, Rs, Zs, z0)
Z0S = get_Z0S(z0, Zs)
X = (Rs, Z0S)
g = Zygote.gradient(X -> calc.luxmodel(X, calc.ps, calc.st)[1], X)[1][1]
@assert length(g) == length(Rs) <= length(dEs)
dEs[1:length(g)] .= g
return dEs
end

# ----- parameter estimation stuff


function lux_energy(at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple)
nlist = ignore_derivatives() do
JuLIP.neighbourlist(at, calc.rcut)
end
return sum( i -> begin
Js, Rs, Zs = ignore_derivatives() do
JuLIP.Potentials.neigsz(nlist, at, i)
end
Z0S = get_Z0S(at.Z[1], Zs)
X = (Rs, Z0S)
Ei, st = calc.luxmodel(X, ps, st)
Ei[1]
end,
1:length(at)
)
end

# function rrule(::typeof(lux_energy), at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple)
# E = lux_energy(at, calc, ps, st)
# function pb(Δ)
# nlist = JuLIP.neighbourlist(at, calc.rcut)
# @show Δ
# error("stop")
# function pb_inner(i)
# Js, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, i)
# gi = ReverseDiff.gradient()
# end
# for i = 1:length(at)
# Ei, st = calc.luxmodel(Rs, calc.ps, calc.st)
# E += Ei[1]
# end
# end
# end

end

##

using JuLIP
JuLIP.usethreads!(false)
ps.dot.W[:] .= 0.01 * randn(length(ps.dot.W))

at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1)
calc = Pot.LuxCalc(model, ps, st, rcut)
JuLIP.energy(calc, at)
JuLIP.forces(calc, at)
JuLIP.virial(calc, at)
Pot.lux_energy(at, calc, ps, st)

@time JuLIP.energy(calc, at)
@time Pot.lux_energy(at, calc, ps, st)
@time JuLIP.forces(calc, at)

##

using Optimisers, ReverseDiff

p_vec, _rest = destructure(ps)
f(_pvec) = Pot.lux_energy(at, calc, _rest(_pvec), st)

f(p_vec)
g = Zygote.gradient(f, p_vec)[1]

@time f(p_vec)
@time Zygote.gradient(f, p_vec)[1]


# This fails for now
# gr = ReverseDiff.gradient(f, p_vec)[1]
25 changes: 25 additions & 0 deletions examples/potential/staticprod.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import Polynomials4ML: _static_prod_ed, _pb_grad_static_prod

function _static_prod_ed(b::NTuple{N, Any}) where N
b2 = b[2:N]
p2, g2 = _static_prod_ed(b2)
return b[1] * p2, tuple(p2, ntuple(i -> b[1] * g2[i], N-1)...)
end

function _static_prod_ed(b::NTuple{1, Any})
return b[1], (one(T),)
end

function _pb_grad_static_prod(∂::NTuple{N, Any}, b::NTuple{N, Any}) where N
∂2 = ∂[2:N]
b2 = b[2:N]
p2, g2, u2 = _pb_grad_static_prod(∂2, b2)
return b[1] * p2,
tuple(p2, ntuple(i -> b[1] * g2[i], N-1)...),
tuple(sum(∂2 .* g2), ntuple(i -> ∂[1] * g2[i] + b[1] * u2[i], N-1)...)
end

function _pb_grad_static_prod(∂::NTuple{1, Any}, b::NTuple{1, Any})
return b[1], (one(T),), (zero(T),)
end

105 changes: 105 additions & 0 deletions examples/potential/test_potential.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote
using Polynomials4ML: LinearLayer, RYlmBasis, lux
using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA
using JuLIP, Combinatorics, Test
using ACEbase.Testing: println_slim, print_tf, fdtest
using Optimisers: destructure
using Printf

include("staticprod.jl")

function grad_test2(f, df, X::AbstractVector; verbose = true)
F = f(X)
∇F = df(X)
nX = length(X)
EE = Matrix(I, (nX, nX))

verbose && @printf("---------|----------- \n")
verbose && @printf(" h | error \n")
verbose && @printf("---------|----------- \n")
for h in 0.1.^(-3:9)
gh = [ (f(X + h * EE[:, i]) - F) / h for i = 1:nX ]
verbose && @printf(" %.1e | %.2e \n", h, norm(gh - ∇F, Inf))
end
end

rng = Random.MersenneTwister()

rcut = 5.5
maxL = 0
L = 0
Aspec, AAspec = degord2spec(; totaldegree = 6,
order = 3,
Lmax = 0, )
cats = AtomicNumber.([:W, :Cu, :Ni, :Fe, :Al])
ipairs = collect(Combinatorics.permutations(1:length(cats), 2))
allcats = collect(SVector{2}.(Combinatorics.permutations(cats, 2)))

for (i, cat) in enumerate(cats)
push!(ipairs, [i, i])
push!(allcats, SVector{2}([cat, cat]))
end

new_spec = []
ori_AAspec = deepcopy(AAspec)
new_AAspec = []

for bb in ori_AAspec
newbb = []
for (t, ip) in zip(bb, ipairs)
push!(newbb, (t..., s = cats[ip]))
end
push!(new_AAspec, newbb)
end

luxchain, ps, st = equivariant_model(new_AAspec, L; categories=allcats, islong = false)

at = rattle!(bulk(:W, cubic=true, pbc=true) * 2, 0.1)
iCu = [5, 12]; iNi = [3, 8]; iAl = [10]; iFe = [6];
at.Z[iCu] .= cats[2]; at.Z[iNi] .= cats[3]; at.Z[iAl] .= cats[4]; at.Z[iFe] .= cats[5];
nlist = JuLIP.neighbourlist(at, rcut)
_, Rs, Zs = JuLIP.Potentials.neigsz(nlist, at, 1)
# centere atom
z0 = at.Z[1]

# serialization, I want the input data structure to lux as simple as possible
get_Z0S(zz0, ZZS) = [SVector{2}(zz0, zzs) for zzs in ZZS]
Z0S = get_Z0S(z0, Zs)

# input of luxmodel
X = (Rs, Z0S)

out, st = luxchain(X, ps, st)


# == lux chain eval and grad
B = out

model = append_layers(luxchain, get1 = WrappedFunction(t -> real.(t)), dot = LinearLayer(length(B), 1), get2 = WrappedFunction(t -> t[1]))

ps, st = Lux.setup(MersenneTwister(1234), model)

model(X, ps, st)

# testing derivative (forces)
g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1]


F(Rs) = model((Rs, Z0S), ps, st)[1]
dF(Rs) = Zygote.gradient(rs -> model((rs, Z0S), ps, st)[1], Rs)[1]

##
@info("test derivative w.r.t X")
print_tf(@test fdtest(F, dF, Rs; verbose=true))


@info("test derivative w.r.t parameter")
p = Zygote.gradient(p -> model(X, p, st)[1], ps)[1]
p, = destructure(p)

W0, re = destructure(ps)
Fp = w -> model(X, re(w), st)[1]
dFp = w -> ( gl = Zygote.gradient(p -> model(X, p, st)[1], ps)[1]; destructure(gl)[1])
grad_test2(Fp, dFp, W0)


Loading

0 comments on commit 9498a15

Please sign in to comment.