Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
CheukHinHoJerry committed Oct 3, 2023
1 parent b1b2771 commit 12e7207
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 100 deletions.
49 changes: 8 additions & 41 deletions src/vmc/Eloc.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,14 @@
export SumH
using ACEpsi.AtomicOrbitals: Nuc
using StaticArrays

# INTERFACE FOR HAMILTIANS H ψ -> H(psi, X)
struct SumH{T}
nuclei::Vector{Nuc{T}}
struct SumH{TK, TT, TE}
K::TK
Vext::TT
Vee::TE
end

function Vee(wf, X::Vector{SVector{3, T}}, ps, st) where {T}
nX = length(X)
v = zero(T)
r = zero(T)
@inbounds begin
for i = 1:nX-1
@simd ivdep for j = i+1:nX
r = norm(X[i]-X[j])
v = muladd(1, 1/r, v)
end
end
end
return v
end

function Vext(wf, X::Vector{SVector{3, T}}, nuclei::Vector{Nuc{TT}}, ps, st) where {T, TT}
nX = length(X)
v = zero(T)
r = zero(T)
@inbounds begin
for i = 1:length(nuclei)
@simd ivdep for j = 1:nX
r = norm(nuclei[i].rr - X[j])
v = muladd(nuclei[i].charge, 1/r, v)
end
end
end
return -v
end

K(wf, X::Vector{SVector{3, T}}, ps, st) where {T} = -0.5 * laplacian(wf, X, ps, st)

(H::SumH)(wf, X::Vector{SVector{3, T}}, ps, st) where {T} =
K(wf, X, ps, st) + (Vext(wf, X, H.nuclei, ps, st) + Vee(wf, X, ps, st)) * evaluate(wf, X, ps, st)
(H::SumH)(wf, X::AbstractVector, ps, st) =
H.K(wf, X, ps, st) + (H.Vext(wf, X, ps, st) + H.Vee(wf, X, ps, st)) * evaluate(wf, X, ps, st)


# evaluate local energy with SumH
Expand All @@ -52,7 +20,6 @@ https://arxiv.org/abs/2105.08351

function Elocal(H::SumH, wf, X::AbstractVector, ps, st)
gra = gradient(wf, X, ps, st)
val = Vext(wf, X, H.nuclei, ps, st) + Vee(wf, X, ps, st) - 1/4 * laplacian(wf, X, ps, st) - 1/8 * gra' * gra
val = H.Vext(wf, X, ps, st) + H.Vee(wf, X, ps, st) - 1/4 * laplacian(wf, X, ps, st) - 1/8 * gra' * gra
return val
end

end
110 changes: 51 additions & 59 deletions src/vmc/metropolis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,34 @@ using Distributed: @spawnat, @fetch, @distributed
using SharedArrays

export MHSampler
using ACEpsi.AtomicOrbitals: Nuc
using Lux: Chain

"""
`MHSampler`
Metropolis-Hastings sampling algorithm.
"""
mutable struct MHSampler{T}
Nel::Int64
nuclei::Vector{Nuc{T}}
mutable struct MHSampler
Nel::Int
Δt::Float64 # step size (of Gaussian proposal)
burnin::Int64 # burn-in iterations
lag::Int64 # iterations between successive samples
N_batch::Int64 # batch size
nchains::Int64 # Number of chains
Ψ::Chain # many-body wavefunction for sampling
x0::Vector # initial sampling
burnin::Int # burn-in iterations
lag::Int # iterations between successive samples
N_batch::Int # batch size
nchains::Int # Number of chains
Ψ # many-body wavefunction for sampling
x0::Any # initial sampling
walkerType::String # walker type: "unbiased", "Langevin"
bc::String # boundary condition
d::sampler_dimension # electron dimension
end

MHSampler(Ψ, Nel, nuclei; Δt = 0.1,
MHSampler(Ψ, Nel; Δt = 0.1,
burnin = 100,
lag = 10,
N_batch = 1,
nchains = 1000,
x0 = Vector{Vector{SVector{3, Float64}}}(undef, nchains),
x0 = [],
wT = "unbiased",
bc = "periodic",
type = 1) =
MHSampler(Nel, nuclei, Δt, burnin, lag, N_batch, nchains, Ψ, x0, wT, bc, type)
d = d3()) =
MHSampler(Nel, Δt, burnin, lag, N_batch, nchains, Ψ, x0, wT, bc, d)


"""
Expand All @@ -46,84 +42,81 @@ biased random walk: R_n+1 = R_n + Δ⋅Wn + Δ⋅∇(log Ψ)(R_n)

eval(wf, X::AbstractVector, ps, st) = wf(X, ps, st)[1]

function MHstep(r0::Vector{Vector{SVector{3, TT}}},
Ψx0::Vector{T},
Nels::Int64,
sam::MHSampler, ps::NamedTuple, st::NamedTuple; batch_size = 1) where {T, TT}
rand_sample(X::Vector{SVector{3, TX}}, Nels::Int, Δt::Float64) where {TX}= begin
return X + Δt * randn(SVector{3, TX}, Nels)
end
rp = rand_sample.(r0, Ref(Nels), Ref(sam.Δt))
function MHstep(r0,
Ψx0,
Nels::Int,
sam::MHSampler, ps, st; batch_size = 1)
rp = rand_sample.(r0, Ref(Nels), Ref(sam.Δt), Ref(sam.d))
raw_data = pmap(rp; batch_size = batch_size) do d
sam.Ψ(d, ps, st)[1]
end
Ψxp = vcat(raw_data)
accprob = accfcn(Ψx0, Ψxp)
u = rand(sam.nchains)
acc = u .<= accprob[:]
r::Vector{Vector{SVector{3, TT}}} = acc .* rp + (1.0 .- acc) .* r0
r = acc .* rp + (1.0 .- acc) .* r0
Ψ = acc .* Ψxp + (1.0 .- acc) .* Ψx0
return r, Ψ, acc
end


rand_sample(X::AbstractVector, Nels::Int, Δt::Number, d::d3) = begin
@view(X[rand(1:Nels)]) .+= Δt * randn(SVector{3, eltype(X[1])}, 1)
return X
end

rand_sample(X::AbstractVector, Nels::Int, Δt::Number, d::T) where T <: Union{d1, d1_lattice} = begin
@view(X[rand(1:Nels)]) .+= Δt * randn(1)
return X
end

"""
acceptance rate for log|Ψ|
ψₜ₊₁²/ψₜ² = exp((log|Ψₜ₊₁|^2-log |ψₜ|^2))
"""

function accfcn(Ψx0::Vector{T}, Ψxp::Vector{T}) where {T}
function accfcn(Ψx0, Ψxp)
acc = exp.(Ψxp .- Ψx0)
return acc
end

"""============== Metropolis sampling algorithm ============
type = "restart"
"""
rand_init(Δt::Number, Nel::Int, nchains::Int, d::d3) = [Δt * randn(SVector{3, Float64}, Nel) for _ = 1:nchains]

function pos(sam::MHSampler)
T = eltype(sam.nuclei[1].rr)
M = length(sam.nuclei)
rr = zeros(SVector{3, T}, sam.Nel)
tt = zeros(Int, 1)
@inbounds begin
for i = 1:M
@simd ivdep for j = Int(ceil(sam.nuclei[i].charge))
tt[1] += 1
rr[tt[1]] = sam.nuclei[i].rr
end
end
end
return rr
end
rand_init(Δt::Number, Nel::Int, nchains::Int, d::d1) = [Δt * randn(Nel) for _ = 1:nchains]

# same as d1 rand_init, except shifted by equally spaced lattice
rand_init(Δt::Number, Nel::Int, nchains::Int, d::d1_lattice) = [Δt * randn(Nel) + d.L for _ = 1:nchains]

function sampler_restart(sam::MHSampler, ps, st; batch_size = 1)
r = pos(sam)
T = eltype(r[1])
r0 = sam.x0
r0 = [sam.Δt * randn(SVector{3, T}, sam.Nel) + r for _ = 1:sam.nchains]
r0 = rand_init(sam.Δt, sam.Nel, sam.nchains, sam.d)
Ψx0 = eval.(Ref(sam.Ψ), r0, Ref(ps), Ref(st))
acc = zeros(T, sam.burnin)
for i = 1 : sam.burnin
acc = []
for _ = 1 : sam.burnin
r0, Ψx0, a = MHstep(r0, Ψx0, sam.Nel, sam, ps, st; batch_size = batch_size);
acc[i] = mean(a)
push!(acc,a)
end
return r0, Ψx0, mean(acc)
return r0, Ψx0, mean(mean(acc))
end

"""
type = "continue"
start from the previous sampling x0
"""
function sampler(sam::MHSampler, ps, st; batch_size = 1)
r0 = sam.x0
Ψx0 = eval.(Ref(sam.Ψ), r0, Ref(ps), Ref(st))
T = eltype(r0[1][1])
acc = zeros(T, sam.lag)
if isempty(sam.x0)
r0, Ψx0, = sampler_restart(sam, ps, st; batch_size = batch_size);
else
r0 = sam.x0
Ψx0 = eval.(Ref(sam.Ψ), r0, Ref(ps), Ref(st))
end
acc = []
for i = 1:sam.lag
r0, Ψx0, a = MHstep(r0, Ψx0, sam.Nel, sam, ps, st; batch_size = batch_size);
acc[i] = mean(a)
r0, Ψx0, a = MHstep(r0, Ψx0, sam.Nel, sam, ps, st);
push!(acc, a)
end
return r0, Ψx0, mean(acc)
return r0, Ψx0, mean(mean(acc))
end


Expand All @@ -143,8 +136,7 @@ function Eloc_Exp_TV_clip(wf, ps, st,
sam::MHSampler,
ham::SumH;
clip = 5., batch_size = 1)
x, ~, acc = sampler(sam, ps, st, batch_size = batch_size)
# Eloc = Elocal.(Ref(ham), Ref(wf), x, Ref(ps), Ref(st))
x, ~, acc = sampler(sam, ps, st; batch_size = batch_size)
raw_data = pmap(x; batch_size = batch_size) do d
Elocal(ham, wf, d, ps, st)
end
Expand Down

0 comments on commit 12e7207

Please sign in to comment.