Skip to content

Commit

Permalink
E, F, V together
Browse files Browse the repository at this point in the history
  • Loading branch information
cortner committed Sep 25, 2023
1 parent 1dd831c commit ee91b3c
Showing 1 changed file with 57 additions and 4 deletions.
61 changes: 57 additions & 4 deletions examples/potential/forces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ g = Zygote.gradient(X -> model(X, ps, st)[1], X)[1]
##

module Pot
import JuLIP, Zygote
import JuLIP, Zygote, StaticArrays
import JuLIP: cutoff, Atoms
import ACEbase: evaluate!, evaluate_d!
import StaticArrays: SVector, SMatrix

import ChainRulesCore
import ChainRulesCore: rrule, ignore_derivatives
Expand Down Expand Up @@ -96,6 +97,44 @@ module Pot
)
end


function lux_efv(at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple)
nlist = ignore_derivatives() do
JuLIP.neighbourlist(at, calc.rcut)
end
E = 0.0
F = zeros(SVector{3, Float64}, length(at))
V = zero(SMatrix{3, 3, Float64})
for i = 1:length(at)
Js, Rs, Zs = ignore_derivatives() do
JuLIP.Potentials.neigsz(nlist, at, i)
end
comp = Zygote.withgradient(_X -> calc.luxmodel(_X, ps, st)[1], Rs)
Ei = comp.val
∇Ei = comp.grad[1]
# energy
E += Ei

# Forces
for j = 1:length(Rs)
F[Js[j]] -= ∇Ei[j]
F[i] += ∇Ei[j]
end

# Virial
if length(Rs) > 0
V -= sum(∇Eij * Rij' for (∇Eij, Rij) in zip(∇Ei, Rs))
end
end

return E, F, V
end

# site_virial(dV::AbstractVector{JVec{T1}}, R::AbstractVector{JVec{T2}}
# ) where {T1, T2} = (
# length(R) > 0 ? (- sum( dVi * Ri' for (dVi, Ri) in zip(dV, R) ))
# : zero(JMat{fltype_intersect(T1, T2)})
# )
# function rrule(::typeof(lux_energy), at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple)
# E = lux_energy(at, calc, ps, st)
# function pb(Δ)
Expand Down Expand Up @@ -140,11 +179,25 @@ p_vec, _rest = destructure(ps)
f(_pvec) = Pot.lux_energy(at, calc, _rest(_pvec), st)

f(p_vec)
g = Zygote.gradient(f, p_vec)[1]
gz = Zygote.gradient(f, p_vec)[1]

@time f(p_vec)
@time Zygote.gradient(f, p_vec)[1]

# We can use either Zygote or ReverseDiff for gradients.
gr = ReverseDiff.gradient(f, p_vec)
@show gr gz

@info("Interestingly ReverseDiff is much faster here, almost optimal")
@time f(p_vec)
@time Zygote.gradient(f, p_vec)[1]
@time ReverseDiff.gradient(f, p_vec)

##

@info("Compute Energies, Forces and Virials at the same time")
E, F, V = Pot.lux_efv(at, calc, ps, st)
@show E JuLIP.energy(calc, at)
@show F JuLIP.forces(calc, at)
@show V JuLIP.virial(calc, at)

# This fails for now
# gr = ReverseDiff.gradient(f, p_vec)[1]

0 comments on commit ee91b3c

Please sign in to comment.