Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rq draft #23

Merged
merged 15 commits into from
Nov 8, 2023
11 changes: 2 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,18 @@ ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HyperDualNumbers = "50ceba7f-c3ee-5a84-a6e8-3ad40456ec97"
<<<<<<< HEAD
JLD = "4138dd39-2aa7-5051-a626-17a0bb65d9c8"
=======
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
>>>>>>> parallel_mc
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
ObjectPools = "658cac36-ff0f-48ad-967c-110375d98c9d"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ParallelDataTransfer = "2dcacdae-9679-587a-88bb-8b444fb7085b"
Polynomials4ML = "03c4bcba-a943-47e9-bfa1-b1661fc2974f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -41,7 +35,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ACEbase = "0.4.2"
BenchmarkTools = "1"
ForwardDiff = "0.10"
JSON = "0.21"
ObjectPools = "0.3.1"
julia = "1"

Expand Down
4 changes: 2 additions & 2 deletions src/ACEpsi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ include("hyper.jl")
include("spins.jl")

# the old 1d backflow code, keep around for now...
# include("envelope.jl")
# include("bflow.jl")
include("envelope.jl")
include("bflow.jl")

# the new 3d backflow code
include("atomicorbitals/atomicorbitals.jl")
Expand Down
14 changes: 8 additions & 6 deletions src/bflow1d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ function get_spec(spec1p)
return spec[:]
end

function BFwf1d_lux(Nel::Integer, Pn::OrthPolyBasis1D3T; totdeg = 15,
function BFwf1d_lux(Nel::Integer, Pn; totdeg = 15,
ν = 3, T = Float64,
sd_admissible = bb -> prod(b.s != '∅' for b in bb) == 0)
sd_admissible = bb -> prod(b.s != '∅' for b in bb) == 0,
trans = identity)

spec1p = [(n = n) for n = 1:totdeg]

l_trans = WrappedFunction(x -> trans.(x))
l_Pn = Polynomials4ML.lux(Pn)
# ----------- Lux connections ---------
# BackFlowPooling: (length(nuclei), nX, length(spec1 from totaldegree)) -> (nX, 3, length(nuclei), length(spec1))
Expand Down Expand Up @@ -63,9 +65,9 @@ function BFwf1d_lux(Nel::Integer, Pn::OrthPolyBasis1D3T; totdeg = 15,
reshape_func = x -> reshape(x, (size(x, 1), prod(size(x)[2:end])))

_det = x -> size(x) == (1, 1) ? x[1,1] : det(Matrix(x))
BFwf_chain = Chain(; Pn = l_Pn, bA = pooling_layer, reshape = WrappedFunction(reshape_func),
bAA = corr_layer, hidden1 = DenseLayer(Nel, length(corr1)),
Mask = ACEpsi.MaskLayer(Nel), det = WrappedFunction(x -> _det(x)), prod = WrappedFunction(x -> prod(x)), logabs = WrappedFunction(x -> 2 * log(abs(x))))
return BFwf_chain
BFwf_chain = Chain(; trans = l_trans, Pn = l_Pn, bA = pooling_layer, reshape = WrappedFunction(reshape_func),
bAA = corr_layer, hidden1 = LinearLayer(length(corr1), Nel),
Mask = ACEpsi.MaskLayer(Nel), det = WrappedFunction(x -> (_det(x))), prod = WrappedFunction(x -> prod(x)), logabs = WrappedFunction(x -> 2 * log(abs(x))))
return BFwf_chain, spec, spec1p
end

207 changes: 165 additions & 42 deletions src/bflow1dps.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
using Polynomials4ML, Random
# using Polynomials4ML: OrthPolyBasis1D3T
using Polynomials4ML: AbstractPoly4MLBasis, PooledSparseProduct, SparseSymmProdDAG, SparseSymmProd, LinearLayer
using ObjectPools: release!
using Polynomials4ML.Utils: gensparse
using LinearAlgebra: qr, I, logabsdet, pinv, mul!, dot , tr, det
import ForwardDiff
using ACEpsi.AtomicOrbitals: make_nlms_spec
using ACEpsi: ↑, ↓, ∅, spins, extspins, Spin, spin2idx, idx2spin
using ACEpsi: ↑, ↓, ∅, spins, extspins, Spin, spin2idx, idx2spin, JCasino1dVb, JCasinoChain
using ACEpsi
using LuxCore: AbstractExplicitLayer
using LuxCore
Expand All @@ -15,8 +13,10 @@ using Lux: Chain, WrappedFunction, BranchLayer
using ChainRulesCore
using ChainRulesCore: NoTangent
using Zygote
using StaticArrays: SA

""" Embed by displacement
"""
Embed by displacement
"""
function embed_diff_func(Xt, i)
T = eltype(Xt)
Expand All @@ -29,7 +29,8 @@ function embed_diff_func(Xt, i)
return copy(Xts)
end

""" trivial embedding
"""
trivial embedding. This should be removed later.
"""
function embed_usual_func(Xt, i)
T = eltype(Xt)
Expand All @@ -41,15 +42,37 @@ function embed_usual_func(Xt, i)
return copy(Xts)
end

function get_spec_Wigner(spec1p)
spec = []
spin = SA['∅','σ']

spec = Array{Any}(undef, (2, length(spec1p)))

for (k, n) in enumerate(spec1p)
for (is, s) in enumerate(spin)
spec[is, k] = (s=s, n)
end
end

return spec[:]
end


"""
According to manuscript, the bf orbital ϕ(x1;x2,…,xN) is (partially) symmetric in x2,…,xN.
Thus admissible specs should be subset of (N×Z3) × (N×Z3)_ord^(B-1) (Note Z3={↑,↓,∅}).
This version of BF generates such specs. Previous version generates subset of (N×Z3)_ord^B.

function BFwf1dps_lux(Nel::Integer, Pn::AbstractPoly4MLBasis; totdeg = length(Pn),
Future: In principle, order 1 orbital/discretization can be different from orbtials defining pooled basis A
"""
function BFwfTrig_lux(Nel::Integer, Pn::AbstractPoly4MLBasis; totdeg = length(Pn),
ν = 3, T = Float64, trans = x -> x,
sd_admissible = bb -> prod(b.s != '∅' for b in bb) == 0)

# create as much as we can first, and then filter later
spec1p = [(n = n) for n = 1:totdeg]

K = length(Pn) # mathcing ACESchrodinger
spec1p = [(n = n) for n = 1:K]


l_trans = Lux.WrappedFunction(x -> trans.(x))
l_Pn = Polynomials4ML.lux(Pn)
# ----------- Lux connections ---------
Expand All @@ -58,18 +81,27 @@ function BFwf1dps_lux(Nel::Integer, Pn::AbstractPoly4MLBasis; totdeg = length(Pn
pooling_layer = ACEpsi.lux(pooling)

spec1p = get_spec(spec1p)
# define sparse for n-correlations
tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ]
default_admissible = bb -> (length(bb) == 0) || (sum(b.n - 1 for b in bb ) <= totdeg)

specAA = gensparse(; NU = ν, tup2b = tup2b, admissible = default_admissible,
minvv = fill(0, ν),
maxvv = fill(length(spec1p), ν),
ordered = true)
spec = [ vv[vv .> 0] for vv in specAA if !(isempty(vv[vv .> 0]))]

# further restrict
spec = [t for t in spec if sd_admissible([spec1p[t[j]] for j = 1:length(t)])]
spec = [[i] for i in eachindex(spec1p)] # spec of order 1
spec = [t for t in spec if spec1p[t[1]].s == '∅'] # body order 1 term should be ∅

if ν > 1
# define sparse for (n-1)-correlations for order ≥ 2 terms
tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ]
default_admissible = bb -> (length(bb) == 0) || (sum(ceil((b.n - 1)/2) for b in bb ) <= ceil((totdeg+ν)/2)) # totdeg>1 is P4ML (unnatural) index for Rtrig

specAA = gensparse(; NU = ν-1, tup2b = tup2b, admissible = default_admissible,
minvv = fill(0, ν-1),
maxvv = fill(length(spec1p), ν-1),
ordered = true) ## gensparse automatically order the spec (it assumes S_N symmetry)
spec_bf = [ vv[vv .> 0] for vv in specAA if !(isempty(vv[vv .> 0]))]

# combining with order 1 orital
spec_bf = [cat(b, t; dims=1) for b in spec for t in spec_bf]
spec = cat(spec, spec_bf; dims=1) # this is the old spec format, except now each b∈spec only has b[2]≤…≤b[B] ordered

# further restrict
spec = [t for t in spec if sd_admissible([spec1p[t[j]] for j = 1:length(t)])]
end

# define n-correlation
corr1 = Polynomials4ML.SparseSymmProd(spec)
Expand All @@ -79,29 +111,56 @@ function BFwf1dps_lux(Nel::Integer, Pn::AbstractPoly4MLBasis; totdeg = length(Pn

reshape_func = x -> reshape(x, (size(x, 1), prod(size(x)[2:end])))

embed_layers = Tuple(collect(Lux.WrappedFunction(x -> embed_usual_func(x, i)) for i = 1:Nel))
embed_layers = Tuple(collect(Lux.WrappedFunction(x -> embed_diff_func(x, i)) for i = 1:Nel))
l_Pns = Tuple(collect(l_Pn for _ = 1:Nel))


_det = x -> size(x) == (1, 1) ? x[1,1] : det(Matrix(x))

BFwf_chain = Chain(; trans = l_trans, diff = Lux.BranchLayer(embed_layers...), Pn = Lux.Parallel(nothing, l_Pns...), bA = pooling_layer, reshape = WrappedFunction(reshape_func), bAA = corr_layer, hidden1 = LinearLayer(length(corr1), Nel),
BFwf_chain = Chain(; trans = l_trans, diff = Lux.BranchLayer(embed_layers...), Pn = Lux.Parallel(nothing, l_Pns...), bA = pooling_layer, reshape = WrappedFunction(reshape_func), bAA = corr_layer, hidden1 = LinearLayer(length(corr1), Nel), # hidden1 = ACEpsi.DenseLayer(Nel, length(corr1)),
Mask = ACEpsi.MaskLayer(Nel), det = WrappedFunction(x -> _det(x)), logabs = WrappedFunction(x -> 2 * log(abs(x))))
return BFwf_chain, spec, spec1p
end

"""According to manuscript, the bf orbital ϕ(x1;x2,…,xN) is (partially) symmetric in x2,…,xN.
Thus admissible specs should be subset of (N×Z3) × (N×Z3)_ord^(B-1) (Note Z3={↑,↓,∅}).
This version of BF generates such specs. Previous version generates subset of (N×Z3)_ord^B.
"""
WignerLayer
Extra layer for basis sparsification according to Wigner Ansatz
(REF)
"""
struct WignerLayer <: AbstractExplicitLayer end

(l::WignerLayer)(A::AbstractMatrix{T}, ps, st) where T = begin
N = length(st.Σ)
@assert all( t == '↑' for t in st.Σ[1:Int(N / 2)])
@assert all( t == '↓' for t in st.Σ[Int(N / 2) + 1:N])

# size(A,2) = spatial_basis x 3

K = Int(size(A, 2) * 2 / 3)
ATilde = Zygote.Buffer(zeros(T, (N, K)))

# ∅ spin
ATilde[:, 1:2:K] = A[:, 3:3:end]

# non-empty spin (↓) for first N/2 electrons
ATilde[1:Int(N/2), 2:2:K] = A[1:Int(N/2), 2:3:end]

# non-empty spin (↑) for last N/2 electrons
ATilde[Int(N/2)+1:N, 2:2:K] = A[Int(N/2)+1:N, 1:3:end]

return copy(ATilde), st
end

Future: In principle, order 1 orbital can be different from orbtials defining pooled basis A
"""
function BFwf1dps_lux2(Nel::Integer, Pn::AbstractPoly4MLBasis; totdeg = length(Pn),
WIGwfTrig_lux
BFwf Lux chain according to Wigner Ansatz
(REF)
"""
function WIGwfTrig_lux(Nel::Integer, Pn::AbstractPoly4MLBasis; totdeg = length(Pn),
ν = 3, T = Float64, trans = x -> x,
sd_admissible = bb -> prod(b.s != '∅' for b in bb) == 0)

# create as much as we can first, and then filter later
spec1p = [(n = n) for n = 1:length(Pn)]
spec1p = [(n = n) for n = 1:totdeg]


l_trans = Lux.WrappedFunction(x -> trans.(x))
Expand All @@ -110,27 +169,31 @@ function BFwf1dps_lux2(Nel::Integer, Pn::AbstractPoly4MLBasis; totdeg = length(P
# BackFlowPooling: (length(nuclei), nX, length(spec1 from totaldegree)) -> (nX, 3, length(nuclei), length(spec1))
pooling = ACEpsi.BackflowPooling1dps()
pooling_layer = ACEpsi.lux(pooling)

spec1p = get_spec(spec1p)
spec = [[i] for i in eachindex(spec1p)] # spec of order 1

# Wigner spec1p
spec1p_winger = get_spec_Wigner(spec1p)

# initalize spec for n-corr
spec = [[i] for i in eachindex(spec1p_winger)] # spec of order 1
spec = [t for t in spec if spec1p_winger[t[1]].s == '∅'] # body order 1 term should be ∅

if ν > 1
# define sparse for (n-1)-correlations for order ≥ 2 terms
tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ]
default_admissible = bb -> (length(bb) == 0) || (sum(b.n - 1 for b in bb ) <= totdeg)
tup2b = vv -> [ spec1p_winger[v] for v in vv[vv .> 0] ]
default_admissible = bb -> (length(bb) == 0) || (sum(ceil((b.n - 1)/2) for b in bb ) <= ceil((totdeg+ν)/2)) # totdeg>1 is P4ML (unnatural) index for Rtrig

specAA = gensparse(; NU = ν-1, tup2b = tup2b, admissible = default_admissible,
minvv = fill(0, ν-1),
maxvv = fill(length(spec1p), ν-1),
ordered = true)
maxvv = fill(length(spec1p_winger), ν-1),
ordered = true) ## gensparse automatically order the spec (it assumes S_N symmetry)
spec_bf = [ vv[vv .> 0] for vv in specAA if !(isempty(vv[vv .> 0]))]

# combining with order 1 orital
spec_bf = [cat(b, t; dims=1) for b in spec for t in spec_bf]
spec = cat(spec, spec_bf; dims=1) # this is the old spec format, except now each b∈spec only has b[2]≤…≤b[B] ordered
#spec= cat(spec, spec_bf; dims=1) # this is the old spec format, except now each b∈spec only has b[2]≤…≤b[B] ordered

# further restrict
spec = [t for t in spec if sd_admissible([spec1p[t[j]] for j = 1:length(t)])]
spec = [t for t in spec_bf if sd_admissible([spec1p_winger[t[j]] for j = 1:length(t)])]
end

# define n-correlation
Expand All @@ -144,10 +207,70 @@ function BFwf1dps_lux2(Nel::Integer, Pn::AbstractPoly4MLBasis; totdeg = length(P
embed_layers = Tuple(collect(Lux.WrappedFunction(x -> embed_diff_func(x, i)) for i = 1:Nel))
l_Pns = Tuple(collect(l_Pn for _ = 1:Nel))


_det = x -> size(x) == (1, 1) ? x[1,1] : det(Matrix(x))

BFwf_chain = Chain(; trans = l_trans, diff = Lux.BranchLayer(embed_layers...), Pn = Lux.Parallel(nothing, l_Pns...), bA = pooling_layer, reshape = WrappedFunction(reshape_func), bAA = corr_layer, hidden1 = ACEpsi.DenseLayer(Nel, length(corr1)),
BFwf_chain = Chain(; trans = l_trans, diff = Lux.BranchLayer(embed_layers...), Pn = Lux.Parallel(nothing, l_Pns...), bA = pooling_layer, reshape = WrappedFunction(reshape_func), Wigner = WignerLayer(), bAA = corr_layer, hidden1 = LinearLayer(length(corr1), Nel), # hidden1 = ACEpsi.DenseLayer(Nel, length(corr1)),
Mask = ACEpsi.MaskLayer(Nel), det = WrappedFunction(x -> _det(x)), logabs = WrappedFunction(x -> 2 * log(abs(x))))
return BFwf_chain, spec, spec1p
return BFwf_chain, spec, spec1p_winger
end

# === JS factor WIP ===
# function BFJwfTrig_lux(Nel::Integer, Pn::AbstractPoly4MLBasis, J::JCasino1dVb; totdeg = length(Pn),
# ν = 3, T = Float64, trans = x -> x,
# sd_admissible = bb -> prod(b.s != '∅' for b in bb) == 0)

# # create as much as we can first, and then filter later
# spec1p = [(n = n) for n = 1:totdeg]


# l_trans = Lux.WrappedFunction(x -> trans.(x))
# l_Pn = Polynomials4ML.lux(Pn)
# # ----------- Lux connections ---------
# # BackFlowPooling: (length(nuclei), nX, length(spec1 from totaldegree)) -> (nX, 3, length(nuclei), length(spec1))
# pooling = ACEpsi.BackflowPooling1dps()
# pooling_layer = ACEpsi.lux(pooling)

# spec1p = get_spec(spec1p)
# spec = [[i] for i in eachindex(spec1p)] # spec of order 1
# spec = [t for t in spec if spec1p[t[1]].s == '∅'] # body order 1 term should be ∅

# if ν > 1
# # define sparse for (n-1)-correlations for order ≥ 2 terms
# tup2b = vv -> [ spec1p[v] for v in vv[vv .> 0] ]
# default_admissible = bb -> (length(bb) == 0) || (sum(ceil((b.n - 1)/2) for b in bb ) <= ceil((totdeg+ν)/2)) # totdeg>1 is P4ML (unnatural) index for Rtrig

# specAA = gensparse(; NU = ν-1, tup2b = tup2b, admissible = default_admissible,
# minvv = fill(0, ν-1),
# maxvv = fill(length(spec1p), ν-1),
# ordered = true) ## gensparse automatically order the spec (it assumes S_N symmetry)
# spec_bf = [ vv[vv .> 0] for vv in specAA if !(isempty(vv[vv .> 0]))]

# # combining with order 1 orital
# spec_bf = [cat(b, t; dims=1) for b in spec for t in spec_bf]
# spec = cat(spec, spec_bf; dims=1) # this is the old spec format, except now each b∈spec only has b[2]≤…≤b[B] ordered

# # further restrict
# spec = [t for t in spec if sd_admissible([spec1p[t[j]] for j = 1:length(t)])]
# end

# # define n-correlation
# corr1 = Polynomials4ML.SparseSymmProd(spec)

# # (nX, 3, length(nuclei), length(spec1 from totaldegree)) -> (nX, length(spec))
# corr_layer = Polynomials4ML.lux(corr1; use_cache = false)

# reshape_func = x -> reshape(x, (size(x, 1), prod(size(x)[2:end])))

# embed_layers = Tuple(collect(Lux.WrappedFunction(x -> embed_diff_func(x, i)) for i = 1:Nel))
# l_Pns = Tuple(collect(l_Pn for _ = 1:Nel))


# _det = x -> size(x) == (1, 1) ? x[1,1] : det(Matrix(x))

# BFwf_chain = Chain(; Pn = Lux.Parallel(nothing, l_Pns...), bA = pooling_layer, reshape = WrappedFunction(reshape_func), bAA = corr_layer, hidden1 = LinearLayer(length(corr1), Nel), # hidden1 = ACEpsi.DenseLayer(Nel, length(corr1)),
# Mask = ACEpsi.MaskLayer(Nel), det = WrappedFunction(x -> _det(x)), logabs = WrappedFunction(x -> 2 * log(abs(x))))
# Jastrow_chain = JCasinoChain(J)
# BFJwf = Chain(; trans = l_trans, diff = Lux.BranchLayer(embed_layers...), to_be_prod = Lux.BranchLayer(BFwf_chain, Jastrow_chain), Sum = WrappedFunction(x -> x[1] + x[2][1])) # I don't know why Jastrow x[2] is still a vector (size1) here

# return BFJwf, spec, spec1p
# end
Loading
Loading