diff --git a/Project.toml b/Project.toml index a494c1a..985d165 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,9 @@ version = "0.0.2" ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e" BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" diff --git a/src/ConstLinearLayer.jl b/src/ConstLinearLayer.jl new file mode 100644 index 0000000..52ee758 --- /dev/null +++ b/src/ConstLinearLayer.jl @@ -0,0 +1,53 @@ +import ChainRulesCore: rrule +using LuxCore, LinearOperators +using LuxCore: AbstractExplicitLayer + +struct ConstLinearLayer{T} <: AbstractExplicitLayer + op::T +end + +(l::ConstLinearLayer{T})(x::AbstractVector) where T = l.op * x + +(l::ConstLinearLayer{T})(x::AbstractMatrix) where T = begin + Tmp = l(x[1,:]) + for i = 2:size(x,1) + Tmp = [Tmp l(x[i,:])] + end + return Tmp' + end + + (l::ConstLinearLayer)(x::AbstractArray,ps,st) = (l(x), st) + + # NOTE: the following rrule is kept because there is a issue with SparseArray +function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractVector) + val = l(x) + function pb(A) + return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent() + end + return val, pb +end + +function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractArray,ps,st) + val = l(x,ps,st) + function pb(A) + return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent() + end + return val, pb +end + +function _linear_operator_L(L, C, pos, len) + if L == 0 + T = ComplexF64 + fL = let C=C, idx=pos#, T=T + (res, aa) -> mul!(res, C, aa[idx]);# try; mul!(res, C, aa[idx]); catch; mul!(zeros(T,size(C,1)), C, aa[idx]); end + end + else + T = SVector{2L+1,ComplexF64} + fL = let C=C, idx=pos#, T=T + (res, aa) -> begin + res[:] .= C * aa[idx] + end + end + end + return LinearOperator{T}(size(C,1), len, false, false, fL, nothing, nothing; S = Vector{T}) + end \ No newline at end of file diff --git a/src/builder.jl b/src/builder.jl index 1ecc941..ba21e6d 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -1,6 +1,6 @@ using LinearAlgebra using SparseArrays: SparseMatrixCSC, sparse -using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real, Rot3DCoeffs_long, re_basis, SYYVector, mm_filter +using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real, Rot3DCoeffs_long, re_basis, SYYVector, mm_filter, coco_dot using Polynomials4ML: legendre_basis, RYlmBasis, natural_indices, degree using Polynomials4ML.Utils: gensparse using Lux: WrappedFunction @@ -8,6 +8,7 @@ using Lux using Random using Polynomials4ML using StaticArrays +using Combinatorics export equivariant_model, equivariant_SYY_model, equivariant_luxchain_constructor, equivariant_luxchain_constructor_new @@ -23,6 +24,33 @@ _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T},Rot3DCoeffs_real{L,T},Rot3DCoeffs_l spec::Vector{Vector{NamedTuple}}) Return a sparse matrix for symmetrisation of AA basis of spec with equivariance specified by cgen """ +function rpe_basis(A::Union{Rot3DCoeffs,Rot3DCoeffs_long,Rot3DCoeffs_real}, nn::SVector{N, TN}, ll::SVector{N, Int}) where {N, TN} + Ure, Mre = re_basis(A, ll) + G = _gramian(nn, ll, Ure, Mre) + S = svd(G) + rk = rank(Diagonal(S.S); rtol = 1e-7) + Urpe = S.U[:, 1:rk]' + return Diagonal(sqrt.(S.S[1:rk])) * Urpe * Ure, Mre +end + + +function _gramian(nn, ll, Ure, Mre) + N = length(nn) + nre = size(Ure, 1) + G = zeros(Complex{Float64}, nre, nre) + for σ in permutations(1:N) + if (nn[σ] != nn) || (ll[σ] != ll); continue; end + for (iU1, mm1) in enumerate(Mre), (iU2, mm2) in enumerate(Mre) + if mm1[σ] == mm2 + for i1 = 1:nre, i2 = 1:nre + G[i1, i2] += coco_dot(Ure[i1, iU1], Ure[i2, iU2]) + end + end + end + end + return G +end + function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Rot3DCoeffs_long{L,T}}, spec) where {L,T} # allocate triplet format Irow, Jcol = Int[], Int[] @@ -55,7 +83,7 @@ function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Ro if (nn,ll,ss) in nnllset; continue; end # get the Mll indices and coeffs - U, Mll = re_basis(cgen, ll) + U, Mll = rpe_basis(cgen, nn, ll) # conver the Mlls into basis functions (NamedTuples) rpibs = [_nlms2b(nn, ll, mm, ss) for mm in Mll] @@ -83,7 +111,8 @@ function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Ro if (nn,ll) in nnllset; continue; end # get the Mll indices and coeffs - U, Mll = re_basis(cgen, ll) + # U, Mll = re_basis(cgen, ll) + U, Mll = rpe_basis(cgen, nn, ll) # conver the Mlls into basis functions (NamedTuples) rpibs = [_nlms2b(nn, ll, mm) for mm in Mll] @@ -196,6 +225,9 @@ L : Largest equivariance level categories : A list of categories radial_basis : specified radial basis, default using P4ML.legendre_basis """ + +include("ConstLinearLayer.jl") + function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false) if rSH && L > 0 error("rSH is only implemented (for now) for L = 0") @@ -228,9 +260,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= cgen = rSH ? Rot3DCoeffs_real(L) : Rot3DCoeffs(L) # TODO: this should be made group related C = _rpi_A2B_matrix(cgen, spec_nlm) end - - l_sym = islong ? Lux.Parallel(nothing, [WrappedFunction(x -> C[i] * x[pos[i]]) for i = 1:L+1]... ) : WrappedFunction(x -> C * x) - # TODO: make it a Const_LinearLayer instead + l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(_linear_operator_L(i-1,C[i],pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C) # C - A2Bmap luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB) # luxchain = Chain(xx2AA = luxchain_tmp, BB = l_sym) @@ -271,7 +301,7 @@ function equivariant_SYY_model(spec_nlm, radial::Radial_basis, L::Int64; categor cgen = Rot3DCoeffs_long(L) # TODO: this should be made group related C = _rpi_A2B_matrix(cgen, spec_nlm) - l_sym = WrappedFunction(x -> C * x) + l_sym = ConstLinearLayer(C) # C - A2Bmap luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB) diff --git a/src/radial.jl b/src/radial.jl index ecee4db..e305349 100644 --- a/src/radial.jl +++ b/src/radial.jl @@ -22,7 +22,7 @@ Radial_basis(Rnl::AbstractExplicitLayer) = end # it is in its current form just for the purpose of testing - a more specific example can be found in forces.jl -function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_trans::Function=r->1; spec = nothing) +function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_trans::Function=r->r; spec = nothing) if isnothing(spec) try spec = natural_indices(basis) @@ -30,8 +30,6 @@ function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_tr error("The specification of this Radial_basis should be given explicitly!") end end - - f(r) = f_trans(r) * f_cut(r) - return Radial_basis(Chain(getnorm = WrappedFunction(x -> norm.(x)), trans = WrappedFunction(x -> f.(x)), poly = lux(basis), ), spec) + return Radial_basis(Chain(trans = WrappedFunction(x -> f_trans.(norm.(x))), evaluation = Lux.BranchLayer(poly = lux(basis), cutoff = WrappedFunction(x -> f_cut.(x))), env = WrappedFunction(x -> x[1].*x[2]), ), spec) end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index d319690..79bfe95 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -253,3 +253,12 @@ function degord2spec(radial::Radial_basis; totaldegree, order, Lmax, catagories end get_i(i) = WrappedFunction(t -> t[i]) + +function sparse_trans(pos,len::Int64) + @assert maximum(pos) <= len + A = sparse(zeros(Int64, length(pos),len)) + for i = 1:length(pos) + A[i,pos[i]] = 1 + end + return sparse(A) +end \ No newline at end of file diff --git a/test/Yangshuais_energy_est.jl b/test/Yangshuais_energy_est.jl new file mode 100644 index 0000000..fc5a267 --- /dev/null +++ b/test/Yangshuais_energy_est.jl @@ -0,0 +1,224 @@ +using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote +using Polynomials4ML: LinearLayer, RYlmBasis, lux, legendre_basis +using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA, simple_radial_basis +rng = Random.MersenneTwister() + +# dataset +using ASE, JuLIP +function gen_dat() + eam = JuLIP.Potentials.EAM("test/w_eam4.fs") + at = rattle!(bulk(:W, cubic=true) * 2, 0.2) + set_data!(at, "energy", energy(eam, at)) + return at +end +Random.seed!(0) +train = [gen_dat() for _ = 1:100]; + +rcut = 5.5 +maxL = 0 +totdeg = 10 +ord = 2 + +fcut(rcut::Float64,pin::Int=2,pout::Int=2) = r -> (r < rcut ? abs( (r/rcut)^pin - 1)^pout : 0) +ftrans(r0::Float64=2.0,p::Int=2) = r -> ( (1+r0)/(1+r) )^p +radial = simple_radial_basis(legendre_basis(totdeg),fcut(rcut),ftrans()) + +Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, + order = ord, + Lmax = maxL, ) + +l_basis, ps_basis, st_basis = equivariant_model(AAspec, radial, maxL; islong = false) +X = [ @SVector(randn(3)) for i in 1:10 ] +B = l_basis(X, ps_basis, st_basis)[1] + +# now extend the above BB basis to a model +len_BB = length(B) + +model = append_layer(l_basis, WrappedFunction(t -> real(t)); l_name=:real) +model = append_layer(model, LinearLayer(len_BB, 1); l_name=:dot) +model = append_layer(model, WrappedFunction(t -> t[1]); l_name=:get1) + +ps, st = Lux.setup(rng, model) +out, st = model(X, ps, st) + +module Pot + import JuLIP, Zygote, StaticArrays + import JuLIP: cutoff, Atoms + import ACEbase: evaluate!, evaluate_d! + import StaticArrays: SVector, SMatrix + + import ChainRulesCore + import ChainRulesCore: rrule, ignore_derivatives + + import Optimisers: destructure + + struct LuxCalc <: JuLIP.SitePotential + luxmodel + ps + st + rcut::Float64 + restructure + end + + function LuxCalc(luxmodel, ps, st, rcut) + pvec, rest = destructure(ps) + return LuxCalc(luxmodel, ps, st, rcut, rest) + end + + cutoff(calc::LuxCalc) = calc.rcut + + function evaluate!(tmp, calc::LuxCalc, Rs, Zs, z0) + E, st = calc.luxmodel(Rs, calc.ps, calc.st) + return E[1] + end + + function evaluate_d!(dEs, tmpd, calc::LuxCalc, Rs, Zs, z0) + g = Zygote.gradient(X -> calc.luxmodel(X, calc.ps, calc.st)[1], Rs)[1] + @assert length(g) == length(Rs) <= length(dEs) + dEs[1:length(g)] .= g + return dEs + end + + # ----- parameter estimation stuff + + + function lux_energy(at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple) + nlist = ignore_derivatives() do + JuLIP.neighbourlist(at, calc.rcut) + end + return sum( i -> begin + Js, Rs, Zs = ignore_derivatives() do + JuLIP.Potentials.neigsz(nlist, at, i) + end + Ei, st = calc.luxmodel(Rs, ps, st) + Ei[1] + end, + 1:length(at) + ) + end + + + function lux_efv(at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple) + nlist = ignore_derivatives() do + JuLIP.neighbourlist(at, calc.rcut) + end + E = 0.0 + F = zeros(SVector{3, Float64}, length(at)) + V = zero(SMatrix{3, 3, Float64}) + for i = 1:length(at) + Js, Rs, Zs = ignore_derivatives() do + JuLIP.Potentials.neigsz(nlist, at, i) + end + comp = Zygote.withgradient(_X -> calc.luxmodel(_X, ps, st)[1], Rs) + Ei = comp.val + ∇Ei = comp.grad[1] + # energy + E += Ei + + # Forces + for j = 1:length(Rs) + F[Js[j]] -= ∇Ei[j] + F[i] += ∇Ei[j] + end + + # Virial + if length(Rs) > 0 + V -= sum(∇Eij * Rij' for (∇Eij, Rij) in zip(∇Ei, Rs)) + end + end + + return E, F, V + end + +end + +ps.dot.W[:] .= 0.01 * randn(length(ps.dot.W)) +calc = Pot.LuxCalc(model, ps, st, rcut) + +using Optimisers, ReverseDiff + +p_vec, _rest = destructure(ps) + +# energy loss function +function E_loss(train, calc, p_vec) + ps = _rest(p_vec) + st = calc.st + Eerr = 0 + for at in train + Nat = length(at) + Eref = at.data["energy"].data + E = Pot.lux_energy(at, calc, ps, st) + Eerr += ( (Eref - E) / Nat)^2 + end + return Eerr +end + +p0 = zero.(p_vec) +E_loss(train, calc, p0) +ReverseDiff.gradient(p -> E_loss(train, calc, p), p_vec) +Zygote.gradient(p -> E_loss(train, calc, p), p_vec)[1] + +using Optim +obj_f = x -> E_loss(train, calc, x) +obj_g! = (g, x) -> copyto!(g, ReverseDiff.gradient(p -> E_loss(train, calc, p), x)) +# obj_g! = (g, x) -> copyto!(g, Zygote.gradient(p -> E_loss(train, calc, p), x)[1]) + +using LineSearches: BackTracking +using LineSearches +# solver = Optim.ConjugateGradient()#linesearch = BackTracking(order=2, maxstep=Inf)) +# solver = Optim.GradientDescent() +solver = Optim.BFGS() +# solver = Optim.LBFGS() #alphaguess = LineSearches.InitialHagerZhang(), + # linesearch = BackTracking(order=2, maxstep=Inf) ) + +res = optimize(obj_f, obj_g!, p0, solver, + Optim.Options(g_tol = 1e-6, show_trace = true)) + +Eerrmin = Optim.minimum(res) +RMSE = sqrt(Eerrmin / length(train)) +pargmin = Optim.minimizer(res) +p1 = pargmin + +train = [gen_dat() for _ = 1:200]; +res_new = optimize(obj_f, obj_g!, p1, solver, + Optim.Options(g_tol = 1e-6, show_trace = true)) + +Eerrmin_new = Optim.minimum(res_new) +RMSE_new = sqrt(Eerrmin_new / length(train)) +pargmin_new = Optim.minimizer(res_new) +# plot the fitting result + +ace = Pot.LuxCalc(model, pargmin, st, rcut) +Eref = [] +Eace = [] +for tr in train + exact = tr.data["energy"].data + estim = Pot.lux_energy(tr, ace, _rest(pargmin_new), st) + push!(Eref, exact) + push!(Eace, estim) +end + +test = [gen_dat() for _ = 1:300]; +Eref_te = [] +Eace_te = [] +for te in test + exact = te.data["energy"].data + estim = Pot.lux_energy(te, ace, _rest(pargmin_new), st) + push!(Eref_te, exact) + push!(Eace_te, estim) +end + +MIN = Eref_te |> minimum +MAX = Eref_te |> maximum +using PyPlot +figure() +scatter(Eref, Eace, c="red", alpha=0.4) +scatter(Eref_te, Eace_te, c="blue", alpha=0.4) +plot(-142.3:0.01:-137.5, -142.3:0.01:-137.5, lw=2, c="k", ls="--") +PyPlot.legend(["Train", "Test"], fontsize=14, loc=2); +xlabel("Reference energy") +ylabel("ACE energy") +axis("square") +xlim([-142.3, -137.5]) +ylim([-142.3, -137.5]) +PyPlot.savefig("W_energy_fitting.png") \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 6d78f14..99fc0b2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,4 +8,7 @@ using Test include("test_equiv_with_cate.jl") include("test_rSH_equivariance.jl") end + @testset "Linear_Dependence" begin + include("test_linear_dependence.jl") + end end diff --git a/test/test_linear_dependence.jl b/test/test_linear_dependence.jl new file mode 100644 index 0000000..11aefab --- /dev/null +++ b/test/test_linear_dependence.jl @@ -0,0 +1,34 @@ +using EquivariantModels, StaticArrays, Test, Polynomials4ML, LinearAlgebra +using ACEbase.Testing: print_tf +using EquivariantModels: Radial_basis, degord2spec +using Polynomials4ML:lux +using RepLieGroups + +fcut(rcut::Float64,pin::Int=2,pout::Int=2) = r -> (r < rcut ? abs( (r/rcut)^pin - 1)^pout : 0) +ftrans(r0::Float64=.0,p::Int=2) = r -> ( (1+r0)/(1+r) )^p +rcut = 5.5 +L = 0 + +@info("Testing linear independence of the L = $L equivariant basis") + +for ord = 1:3 + for totdeg = 4:8 + radial = EquivariantModels.simple_radial_basis(MonoBasis(totdeg-1),r->sqrt(r)*fcut(rcut)(r),ftrans()) + + Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, + order = ord, + Lmax = L, islong = false) + luxchain, ps, st = equivariant_model(AAspec, radial, L; islong = false) + F(X) = luxchain(X, ps, st)[1] + X = [ @SVector(rand(3)) for i in 1:10 ] + + T = L == 0 ? ComplexF64 : SVector{2L+1,ComplexF64} + A = zeros(T,length(F(X)),10length(F(X))) + for i = 1:10length(F(X)) + local x = [ @SVector(rand(3)) for i in 1:10 ] + A[:,i] = F(x) + end + print_tf(@test rank(A) == length(F(X))) + end +end +println() \ No newline at end of file