From 01bb9e06549db288e6e84e597f248846b49fbfc1 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Wed, 27 Sep 2023 15:25:13 -0700 Subject: [PATCH 01/20] Initialise ConstLinearLayer --- src/builder.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/builder.jl b/src/builder.jl index aa54760..6c35c1c 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -192,6 +192,15 @@ L : Largest equivariance level categories : A list of categories radial_basis : specified radial basis, default using P4ML.legendre_basis """ + +struct ConstLinearLayer <: AbstractExplicitLayer + in_dim::Integer + out_dim::Integer + use_cache::Bool + # parameters ? + @reqfields() +end + function equivariant_model(spec_nlm, L::Int64; categories=[], d=3, radial_basis=legendre_basis, group="O3", islong=true) # first filt out those unfeasible spec_nlm filter_init = islong ? RPE_filter_long(L) : RPE_filter(L) From 0047583a35414a25aabf3fae6d6f00a44cb0cabb Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Sun, 1 Oct 2023 15:42:31 -0700 Subject: [PATCH 02/20] keep it consistent to the Radial_Basis... branch --- src/builder.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/builder.jl b/src/builder.jl index 14c8e2c..965d599 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -201,7 +201,10 @@ struct ConstLinearLayer <: AbstractExplicitLayer @reqfields() end -function equivariant_model(spec_nlm, L::Int64; categories=[], d=3, radial_basis=legendre_basis, group="O3", islong=true) +function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false) + if rSH && L > 0 + error("rSH is only implemented (for now) for L = 0") + end # first filt out those unfeasible spec_nlm filter_init = rSH ? RPE_filter_real(L) : (islong ? RPE_filter_long(L) : RPE_filter(L)) From 8c5ed6c1172ea38d9c61738179c832e9b4c22cb3 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Mon, 2 Oct 2023 09:46:34 -0700 Subject: [PATCH 03/20] WIP - far from complete --- src/ConstLinearLayer.jl | 47 +++++++++++++++++++++++++++++++++++++++++ src/builder.jl | 11 +++------- 2 files changed, 50 insertions(+), 8 deletions(-) create mode 100644 src/ConstLinearLayer.jl diff --git a/src/ConstLinearLayer.jl b/src/ConstLinearLayer.jl new file mode 100644 index 0000000..e3f9c33 --- /dev/null +++ b/src/ConstLinearLayer.jl @@ -0,0 +1,47 @@ +import ChainRulesCore: rrule +using LuxCore +using LuxCore: AbstractExplicitLayer + +struct ConstLinearLayer{T} <: AbstractExplicitLayer # where {in_dim,out_dim,T} + W::AbstractMatrix{T} + in_dim::Integer + out_dim::Integer +end + +ConstLinearLayer(W::AbstractMatrix{T}) where T = ConstLinearLayer(W,size(W,2),size(W,1)) + +(l::ConstLinearLayer)(x::AbstractVector) = l.in_dim == length(x) ? l.W * x : error("x has a wrong length!") + +(l::ConstLinearLayer)(x::AbstractMatrix) = begin + Tmp = l(x[1,:]) + for i = 2:size(x,1) + Tmp = [Tmp l(x[i,:])] + end + return Tmp' + end + +function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractVector) + val = l(x) + function pb(A) + return NoTangent(), NoTangent(), l.W' * A[1], (W = A[1] * x',), NoTangent() + end + return val, pb +end + +(l::ConstLinearLayer)(x::AbstractArray,ps,st) = l(x) + +function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractArray,ps,st) + val = l(x) + function pb(A) + return NoTangent(), NoTangent(), l.W' * A[1], (W = A[1] * x',), NoTangent() + end + return val, pb +end + +# function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractMatrix, ps, st) +# val = l(x, ps, st) +# function pb(A) +# return NoTangent(), NoTangent(), l.W' * A[1], (W = A[1] * x',), NoTangent() +# end +# return val, pb +# end \ No newline at end of file diff --git a/src/builder.jl b/src/builder.jl index 965d599..da21e3b 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -193,13 +193,8 @@ L : Largest equivariance level categories : A list of categories radial_basis : specified radial basis, default using P4ML.legendre_basis """ -struct ConstLinearLayer <: AbstractExplicitLayer - in_dim::Integer - out_dim::Integer - use_cache::Bool - # parameters ? - @reqfields() -end + +include("ConstLinearLayer.jl") function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false) if rSH && L > 0 @@ -234,7 +229,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= C = _rpi_A2B_matrix(cgen, spec_nlm) end - l_sym = islong ? Lux.Parallel(nothing, [WrappedFunction(x -> C[i] * x[pos[i]]) for i = 1:L+1]... ) : WrappedFunction(x -> C * x) + l_sym = islong ? Lux.Parallel(nothing, [WrappedFunction(x -> C[i] * x[pos[i]]) for i = 1:L+1]... ) : ConstLinearLayer(C) # WrappedFunction(x -> C * x) # TODO: make it a Const_LinearLayer instead # C - A2Bmap luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB) From 02e303deb52cf952812d2c87a8790d5c6f987f62 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Tue, 3 Oct 2023 18:17:45 -0700 Subject: [PATCH 04/20] partially works --- src/ConstLinearLayer.jl | 4 ++-- src/builder.jl | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/ConstLinearLayer.jl b/src/ConstLinearLayer.jl index e3f9c33..10d3255 100644 --- a/src/ConstLinearLayer.jl +++ b/src/ConstLinearLayer.jl @@ -28,10 +28,10 @@ function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractVector) return val, pb end -(l::ConstLinearLayer)(x::AbstractArray,ps,st) = l(x) +(l::ConstLinearLayer)(x::AbstractArray,ps,st) = (l(x), st) function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractArray,ps,st) - val = l(x) + val = l(x,ps,st) function pb(A) return NoTangent(), NoTangent(), l.W' * A[1], (W = A[1] * x',), NoTangent() end diff --git a/src/builder.jl b/src/builder.jl index da21e3b..9221188 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -229,7 +229,8 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= C = _rpi_A2B_matrix(cgen, spec_nlm) end - l_sym = islong ? Lux.Parallel(nothing, [WrappedFunction(x -> C[i] * x[pos[i]]) for i = 1:L+1]... ) : ConstLinearLayer(C) # WrappedFunction(x -> C * x) + # l_sym = islong ? Lux.Parallel(nothing, [WrappedFunction(x -> C[i] * x[pos[i]]) for i = 1:L+1]... ) : WrappedFunction(x -> C * x) + l_sym = islong ? Lux.Parallel(nothing, [WrappedFunction(x -> C[i] * x[pos[i]]) for i = 1:L+1]... ) : ConstLinearLayer(C) # TODO: make it a Const_LinearLayer instead # C - A2Bmap luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB) From 8e7fcf0ea7813a2319208a5386a8050b91327c30 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Tue, 3 Oct 2023 18:55:32 -0700 Subject: [PATCH 05/20] Clean up --- src/ConstLinearLayer.jl | 6 ++++-- src/builder.jl | 6 ++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/ConstLinearLayer.jl b/src/ConstLinearLayer.jl index 10d3255..22714a1 100644 --- a/src/ConstLinearLayer.jl +++ b/src/ConstLinearLayer.jl @@ -4,13 +4,15 @@ using LuxCore: AbstractExplicitLayer struct ConstLinearLayer{T} <: AbstractExplicitLayer # where {in_dim,out_dim,T} W::AbstractMatrix{T} + position::Union{Vector{Int64}, UnitRange{Int64}} in_dim::Integer out_dim::Integer end -ConstLinearLayer(W::AbstractMatrix{T}) where T = ConstLinearLayer(W,size(W,2),size(W,1)) +ConstLinearLayer(W::AbstractMatrix{T}) where T = ConstLinearLayer(W,1:size(W,2),size(W,2),size(W,1)) +ConstLinearLayer(W::AbstractMatrix{T}, pos::Union{Vector{Int64}, UnitRange{Int64}}) where T = ConstLinearLayer(W,pos,size(W,2),size(W,1)) -(l::ConstLinearLayer)(x::AbstractVector) = l.in_dim == length(x) ? l.W * x : error("x has a wrong length!") +(l::ConstLinearLayer)(x::AbstractVector) = l.in_dim == length(x[l.position]) ? l.W * x[l.position] : error("x (or the position index) has a wrong length!") (l::ConstLinearLayer)(x::AbstractMatrix) = begin Tmp = l(x[1,:]) diff --git a/src/builder.jl b/src/builder.jl index 9221188..0cf749b 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -229,9 +229,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= C = _rpi_A2B_matrix(cgen, spec_nlm) end - # l_sym = islong ? Lux.Parallel(nothing, [WrappedFunction(x -> C[i] * x[pos[i]]) for i = 1:L+1]... ) : WrappedFunction(x -> C * x) - l_sym = islong ? Lux.Parallel(nothing, [WrappedFunction(x -> C[i] * x[pos[i]]) for i = 1:L+1]... ) : ConstLinearLayer(C) - # TODO: make it a Const_LinearLayer instead + l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(C[i],pos[i]) for i = 1:L+1]... ) : ConstLinearLayer(C) # C - A2Bmap luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB) # luxchain = Chain(xx2AA = luxchain_tmp, BB = l_sym) @@ -272,7 +270,7 @@ function equivariant_SYY_model(spec_nlm, radial::Radial_basis, L::Int64; categor cgen = Rot3DCoeffs_long(L) # TODO: this should be made group related C = _rpi_A2B_matrix(cgen, spec_nlm) - l_sym = WrappedFunction(x -> C * x) + l_sym = ConstLinearLayer(C) # C - A2Bmap luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB) From 53a472d03f8761c4934bc17c0e4148f73199e6c2 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Tue, 3 Oct 2023 23:48:40 -0700 Subject: [PATCH 06/20] Adapt Yangshuai's energy fitting code to the latest version --- test/Yangshuais_energy_est.jl | 222 ++++++++++++++++++++++++++++++++++ 1 file changed, 222 insertions(+) create mode 100644 test/Yangshuais_energy_est.jl diff --git a/test/Yangshuais_energy_est.jl b/test/Yangshuais_energy_est.jl new file mode 100644 index 0000000..87b619d --- /dev/null +++ b/test/Yangshuais_energy_est.jl @@ -0,0 +1,222 @@ +using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Zygote +using Polynomials4ML: LinearLayer, RYlmBasis, lux, legendre_basis +using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA, simple_radial_basis +rng = Random.MersenneTwister() + +# dataset +using ASE, JuLIP +function gen_dat() + eam = JuLIP.Potentials.EAM("test/w_eam4.fs") + at = rattle!(bulk(:W, cubic=true) * 2, 0.1) + set_data!(at, "energy", energy(eam, at)) + return at +end +Random.seed!(0) +train = [gen_dat() for _ = 1:100]; + +rcut = 5.5 +maxL = 0 +totdeg = 6 +ord = 3 + +fcut(rcut::Float64,pin::Int=2,pout::Int=2) = r -> (r < rcut ? abs( (r/rcut)^pin - 1)^pout : 0) +ftrans(r0::Float64=.0,p::Int=2) = r -> ( (1+r0)/(1+r) )^p +radial = simple_radial_basis(legendre_basis(totdeg),fcut(rcut),ftrans()) + +Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, + order = ord, + Lmax = maxL, ) + +l_basis, ps_basis, st_basis = equivariant_model(AAspec, radial, maxL; islong = false) +X = [ @SVector(randn(3)) for i in 1:10 ] +B = l_basis(X, ps_basis, st_basis)[1] + +# now extend the above BB basis to a model +len_BB = length(B) + +model = append_layer(l_basis, WrappedFunction(t -> real(t)); l_name=:real) +model = append_layer(model, LinearLayer(len_BB, 1); l_name=:dot) +model = append_layer(model, WrappedFunction(t -> t[1]); l_name=:get1) + +ps, st = Lux.setup(rng, model) +out, st = model(X, ps, st) + +module Pot + import JuLIP, Zygote, StaticArrays + import JuLIP: cutoff, Atoms + import ACEbase: evaluate!, evaluate_d! + import StaticArrays: SVector, SMatrix + + import ChainRulesCore + import ChainRulesCore: rrule, ignore_derivatives + + import Optimisers: destructure + + struct LuxCalc <: JuLIP.SitePotential + luxmodel + ps + st + rcut::Float64 + restructure + end + + function LuxCalc(luxmodel, ps, st, rcut) + pvec, rest = destructure(ps) + return LuxCalc(luxmodel, ps, st, rcut, rest) + end + + cutoff(calc::LuxCalc) = calc.rcut + + function evaluate!(tmp, calc::LuxCalc, Rs, Zs, z0) + E, st = calc.luxmodel(Rs, calc.ps, calc.st) + return E[1] + end + + function evaluate_d!(dEs, tmpd, calc::LuxCalc, Rs, Zs, z0) + g = Zygote.gradient(X -> calc.luxmodel(X, calc.ps, calc.st)[1], Rs)[1] + @assert length(g) == length(Rs) <= length(dEs) + dEs[1:length(g)] .= g + return dEs + end + + # ----- parameter estimation stuff + + + function lux_energy(at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple) + nlist = ignore_derivatives() do + JuLIP.neighbourlist(at, calc.rcut) + end + return sum( i -> begin + Js, Rs, Zs = ignore_derivatives() do + JuLIP.Potentials.neigsz(nlist, at, i) + end + Ei, st = calc.luxmodel(Rs, ps, st) + Ei[1] + end, + 1:length(at) + ) + end + + + function lux_efv(at::Atoms, calc::LuxCalc, ps::NamedTuple, st::NamedTuple) + nlist = ignore_derivatives() do + JuLIP.neighbourlist(at, calc.rcut) + end + E = 0.0 + F = zeros(SVector{3, Float64}, length(at)) + V = zero(SMatrix{3, 3, Float64}) + for i = 1:length(at) + Js, Rs, Zs = ignore_derivatives() do + JuLIP.Potentials.neigsz(nlist, at, i) + end + comp = Zygote.withgradient(_X -> calc.luxmodel(_X, ps, st)[1], Rs) + Ei = comp.val + ∇Ei = comp.grad[1] + # energy + E += Ei + + # Forces + for j = 1:length(Rs) + F[Js[j]] -= ∇Ei[j] + F[i] += ∇Ei[j] + end + + # Virial + if length(Rs) > 0 + V -= sum(∇Eij * Rij' for (∇Eij, Rij) in zip(∇Ei, Rs)) + end + end + + return E, F, V + end + +end + +ps.dot.W[:] .= 0.01 * randn(length(ps.dot.W)) +calc = Pot.LuxCalc(model, ps, st, rcut) + +using Optimisers, ReverseDiff + +p_vec, _rest = destructure(ps) + +# energy loss function +function E_loss(train, calc, p_vec) + ps = _rest(p_vec) + st = calc.st + Eerr = 0 + for at in train + Nat = length(at) + Eref = at.data["energy"].data + E = Pot.lux_energy(at, calc, ps, st) + Eerr += ( (Eref - E) / Nat)^2 + end + return Eerr +end + +p0 = zero.(p_vec) +E_loss(train, calc, p0) +ReverseDiff.gradient(p -> E_loss(train, calc, p), p_vec) +Zygote.gradient(p -> E_loss(train, calc, p), p_vec)[1] + +using Optim +obj_f = x -> E_loss(train, calc, x) +obj_g! = (g, x) -> copyto!(g, ReverseDiff.gradient(p -> E_loss(train, calc, p), x)) +# obj_g! = (g, x) -> copyto!(g, Zygote.gradient(p -> E_loss(train, calc, p), x)[1]) + +using LineSearches: BackTracking +using LineSearches +# solver = Optim.ConjugateGradient()#linesearch = BackTracking(order=2, maxstep=Inf)) +# solver = Optim.GradientDescent(linesearch = BackTracking(order=2, maxstep=Inf) ) +# solver = Optim.BFGS() +solver = Optim.LBFGS() #alphaguess = LineSearches.InitialHagerZhang(), + # linesearch = BackTracking(order=2, maxstep=Inf) ) + +res = optimize(obj_f, obj_g!, p0, solver, + Optim.Options(f_tol = 1e-10, g_tol = 1e-6, show_trace = true)) + +Eerrmin = Optim.minimum(res) +RMSE = sqrt(Eerrmin / length(train)) +pargmin = Optim.minimizer(res) +p1 = pargmin + +train = [gen_dat() for _ = 1:500]; +res_new = optimize(obj_f, obj_g!, p1, solver, + Optim.Options(f_tol = 1e-10, g_tol = 1e-5, show_trace = true)) + +Eerrmin_new = Optim.minimum(res_new) +RMSE_new = sqrt(Eerrmin_new / length(train)) +pargmin_new = Optim.minimizer(res_new) +# plot the fitting result + +ace = Pot.LuxCalc(model, pargmin, st, rcut) +Eref = [] +Eace = [] +for tr in train + exact = tr.data["energy"].data + estim = Pot.lux_energy(tr, ace, _rest(pargmin_new), st) + push!(Eref, exact) + push!(Eace, estim) +end + +test = [gen_dat() for _ = 1:300]; +Eref_te = [] +Eace_te = [] +for te in test + exact = te.data["energy"].data + estim = Pot.lux_energy(te, ace, _rest(pargmin_new), st) + push!(Eref_te, exact) + push!(Eace_te, estim) +end + +using PyPlot +figure() +scatter(Eref, Eace, c="red", alpha=0.4) +scatter(Eref_te, Eace_te, c="blue", alpha=0.4) +plot(-142.3:0.01:-141.5, -142.3:0.01:-141.5, lw=2, c="k", ls="--") +PyPlot.legend(["Train", "Test"], fontsize=14, loc=2); +xlabel("Reference energy") +ylabel("ACE energy") +axis("square") +xlim([-142.3, -141.5]) +ylim([-142.3, -141.5]) +PyPlot.savefig("W_energy_fitting.png") \ No newline at end of file From 4a7a06d04a6986c581cecfb79b3fbb6ea2b7dd79 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Wed, 4 Oct 2023 22:14:42 -0700 Subject: [PATCH 07/20] introduce rpe_basis & a linear dependence test --- src/builder.jl | 35 ++++++++++++++++++++++--- test/test_linear_dependence.jl | 47 ++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 3 deletions(-) create mode 100644 test/test_linear_dependence.jl diff --git a/src/builder.jl b/src/builder.jl index fbd9965..0a35c50 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -1,6 +1,6 @@ using LinearAlgebra using SparseArrays: SparseMatrixCSC, sparse -using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real, Rot3DCoeffs_long, re_basis, SYYVector, mm_filter +using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real, Rot3DCoeffs_long, re_basis, SYYVector, mm_filter, coco_dot using Polynomials4ML: legendre_basis, RYlmBasis, natural_indices, degree using Polynomials4ML.Utils: gensparse using Lux: WrappedFunction @@ -8,6 +8,7 @@ using Lux using Random using Polynomials4ML using StaticArrays +using Combinatorics export equivariant_model, equivariant_SYY_model, equivariant_luxchain_constructor, equivariant_luxchain_constructor_new @@ -23,6 +24,33 @@ _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T},Rot3DCoeffs_real{L,T},Rot3DCoeffs_l spec::Vector{Vector{NamedTuple}}) Return a sparse matrix for symmetrisation of AA basis of spec with equivariance specified by cgen """ +function rpe_basis(A::Union{Rot3DCoeffs,Rot3DCoeffs_long,Rot3DCoeffs_real}, nn::SVector{N, TN}, ll::SVector{N, Int}) where {N, TN} + Ure, Mre = re_basis(A, ll) + G = _gramian(nn, ll, Ure, Mre) + S = svd(G) + rk = rank(Diagonal(S.S); rtol = 1e-7) + Urpe = S.U[:, 1:rk]' + return Diagonal(sqrt.(S.S[1:rk])) * Urpe * Ure, Mre +end + + +function _gramian(nn, ll, Ure, Mre) + N = length(nn) + nre = size(Ure, 1) + G = zeros(Complex{Float64}, nre, nre) + for σ in permutations(1:N) + if (nn[σ] != nn) || (ll[σ] != ll); continue; end + for (iU1, mm1) in enumerate(Mre), (iU2, mm2) in enumerate(Mre) + if mm1[σ] == mm2 + for i1 = 1:nre, i2 = 1:nre + G[i1, i2] += coco_dot(Ure[i1, iU1], Ure[i2, iU2]) + end + end + end + end + return G +end + function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Rot3DCoeffs_long{L,T}}, spec) where {L,T} # allocate triplet format Irow, Jcol = Int[], Int[] @@ -55,7 +83,7 @@ function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Ro if (nn,ll,ss) in nnllset; continue; end # get the Mll indices and coeffs - U, Mll = re_basis(cgen, ll) + U, Mll = rpe_basis(cgen, nn, ll) # conver the Mlls into basis functions (NamedTuples) rpibs = [_nlms2b(nn, ll, mm, ss) for mm in Mll] @@ -83,7 +111,8 @@ function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Ro if (nn,ll) in nnllset; continue; end # get the Mll indices and coeffs - U, Mll = re_basis(cgen, ll) + # U, Mll = re_basis(cgen, ll) + U, Mll = rpe_basis(cgen, nn, ll) # conver the Mlls into basis functions (NamedTuples) rpibs = [_nlms2b(nn, ll, mm) for mm in Mll] diff --git a/test/test_linear_dependence.jl b/test/test_linear_dependence.jl new file mode 100644 index 0000000..0b10d59 --- /dev/null +++ b/test/test_linear_dependence.jl @@ -0,0 +1,47 @@ +using EquivariantModels, StaticArrays, Test, Polynomials4ML, LinearAlgebra +using ACEbase.Testing: print_tf +using Rotations, WignerD, BlockDiagonals +using EquivariantModels: Radial_basis, xx2AA, degord2spec +using Polynomials4ML:lux +using RepLieGroups + +include("wigner.jl") + +@info("Testing the chain that generates a single B basis") +rcut = 5.5 +totdeg = 6 +ν = 3 +Lmax = 2 +fcut(rcut::Float64,pin::Int=2,pout::Int=2) = r -> (r < rcut ? abs( (r/rcut)^pin - 1)^pout : 0) +ftrans(r0::Float64=.0,p::Int=2) = r -> ( (1+r0)/(1+r) )^p +radial = EquivariantModels.simple_radial_basis(MonoBasis(totdeg-1),r->sqrt(r)*fcut(rcut,2,2)(r),r->1/sqrt(r)*ftrans(1.0,2)(r)) + +L = 0 +Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, + order = ν, + Lmax = L, islong = false) +# luxchain, ps, st = xx2AA(AAspec, radial) +luxchain, ps, st = equivariant_model(AAspec, radial, L; islong = false) +F(X) = luxchain(X, ps, st)[1] + +X = [ @SVector(rand(3)) for i in 1:10 ] +F(X) + +T = L == 0 ? ComplexF64 : SVector{2L+1,ComplexF64} +A = zeros(T,length(F(X)),3length(F(X))) +for i = 1:3length(F(X)) + local x = [ @SVector(rand(3)) for i in 1:10 ] + A[:,i] = F(x) +end +B = +try + A*A' +catch + B = zeros(ComplexF64,length(F(X)),length(F(X))) + for i = 1:length(F(X)) + for j = 1:length(F(X)) + B[i,j] = sum(RepLieGroups.O3.coco_dot(A[i,t], A[j,t]) for t = 1:3length(F(X))) + end + end +end +rank(B) \ No newline at end of file From 7252202ee7060dc8213e6e0711e384c847a5b8f4 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Thu, 5 Oct 2023 00:06:41 -0700 Subject: [PATCH 08/20] Fix the linear dependence issue --- src/radial.jl | 6 ++-- test/test_linear_dependence.jl | 52 +++++++++++++--------------------- 2 files changed, 22 insertions(+), 36 deletions(-) diff --git a/src/radial.jl b/src/radial.jl index ecee4db..f51c738 100644 --- a/src/radial.jl +++ b/src/radial.jl @@ -22,7 +22,7 @@ Radial_basis(Rnl::AbstractExplicitLayer) = end # it is in its current form just for the purpose of testing - a more specific example can be found in forces.jl -function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_trans::Function=r->1; spec = nothing) +function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_trans::Function=r->r; spec = nothing) if isnothing(spec) try spec = natural_indices(basis) @@ -30,8 +30,6 @@ function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_tr error("The specification of this Radial_basis should be given explicitly!") end end - - f(r) = f_trans(r) * f_cut(r) - return Radial_basis(Chain(getnorm = WrappedFunction(x -> norm.(x)), trans = WrappedFunction(x -> f.(x)), poly = lux(basis), ), spec) + return Radial_basis(Chain(trans = WrappedFunction(x -> f_trans.(norm.(x))), evluation = Lux.BranchLayer(poly = lux(basis), cutoff = WrappedFunction(x -> f_cut.(x))), env = WrappedFunction(x -> x[1].*x[2]), ), spec) end \ No newline at end of file diff --git a/test/test_linear_dependence.jl b/test/test_linear_dependence.jl index 0b10d59..6f79419 100644 --- a/test/test_linear_dependence.jl +++ b/test/test_linear_dependence.jl @@ -5,43 +5,31 @@ using EquivariantModels: Radial_basis, xx2AA, degord2spec using Polynomials4ML:lux using RepLieGroups -include("wigner.jl") - -@info("Testing the chain that generates a single B basis") -rcut = 5.5 -totdeg = 6 -ν = 3 -Lmax = 2 fcut(rcut::Float64,pin::Int=2,pout::Int=2) = r -> (r < rcut ? abs( (r/rcut)^pin - 1)^pout : 0) ftrans(r0::Float64=.0,p::Int=2) = r -> ( (1+r0)/(1+r) )^p -radial = EquivariantModels.simple_radial_basis(MonoBasis(totdeg-1),r->sqrt(r)*fcut(rcut,2,2)(r),r->1/sqrt(r)*ftrans(1.0,2)(r)) - +rcut = 5.5 L = 0 -Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, - order = ν, - Lmax = L, islong = false) -# luxchain, ps, st = xx2AA(AAspec, radial) -luxchain, ps, st = equivariant_model(AAspec, radial, L; islong = false) -F(X) = luxchain(X, ps, st)[1] -X = [ @SVector(rand(3)) for i in 1:10 ] -F(X) +@info("Testing linear independence of the L = $L equivariant basis") -T = L == 0 ? ComplexF64 : SVector{2L+1,ComplexF64} -A = zeros(T,length(F(X)),3length(F(X))) -for i = 1:3length(F(X)) - local x = [ @SVector(rand(3)) for i in 1:10 ] - A[:,i] = F(x) -end -B = -try - A*A' -catch - B = zeros(ComplexF64,length(F(X)),length(F(X))) - for i = 1:length(F(X)) - for j = 1:length(F(X)) - B[i,j] = sum(RepLieGroups.O3.coco_dot(A[i,t], A[j,t]) for t = 1:3length(F(X))) +for ord = 1:3 + for totdeg = 4:8 + radial = EquivariantModels.simple_radial_basis(MonoBasis(totdeg-1),r->sqrt(r)*fcut(rcut)(r),ftrans()) + + Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, + order = ord, + Lmax = L, islong = false) + luxchain, ps, st = equivariant_model(AAspec, radial, L; islong = false) + F(X) = luxchain(X, ps, st)[1] + X = [ @SVector(rand(3)) for i in 1:10 ] + + T = L == 0 ? ComplexF64 : SVector{2L+1,ComplexF64} + A = zeros(T,length(F(X)),10length(F(X))) + for i = 1:10length(F(X)) + local x = [ @SVector(rand(3)) for i in 1:10 ] + A[:,i] = F(x) end + print_tf(@test rank(A) == length(F(X))) end end -rank(B) \ No newline at end of file +println() \ No newline at end of file From 7dc53bce180c06451e0e54a8b37d8a4b07a9db4e Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Thu, 5 Oct 2023 00:11:22 -0700 Subject: [PATCH 09/20] Update Project.toml --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index a494c1a..088c94e 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.0.2" ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e" BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" From 92b2b3d27148479dffba74f4b44cfe4256d8990e Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Thu, 5 Oct 2023 00:19:55 -0700 Subject: [PATCH 10/20] Typo fix --- src/radial.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/radial.jl b/src/radial.jl index f51c738..e305349 100644 --- a/src/radial.jl +++ b/src/radial.jl @@ -31,5 +31,5 @@ function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_tr end end - return Radial_basis(Chain(trans = WrappedFunction(x -> f_trans.(norm.(x))), evluation = Lux.BranchLayer(poly = lux(basis), cutoff = WrappedFunction(x -> f_cut.(x))), env = WrappedFunction(x -> x[1].*x[2]), ), spec) + return Radial_basis(Chain(trans = WrappedFunction(x -> f_trans.(norm.(x))), evaluation = Lux.BranchLayer(poly = lux(basis), cutoff = WrappedFunction(x -> f_cut.(x))), env = WrappedFunction(x -> x[1].*x[2]), ), spec) end \ No newline at end of file From 43d1a4f03bbc92c155eb613adae4a992440ee049 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Thu, 5 Oct 2023 00:24:54 -0700 Subject: [PATCH 11/20] Add the corresponding tests --- test/runtests.jl | 3 +++ test/test_linear_dependence.jl | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 6d78f14..abce0e8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,4 +8,7 @@ using Test include("test_equiv_with_cate.jl") include("test_rSH_equivariance.jl") end + @testset "Linear_Dependence" begin + include("test_linear_dependence") + end end diff --git a/test/test_linear_dependence.jl b/test/test_linear_dependence.jl index 6f79419..11aefab 100644 --- a/test/test_linear_dependence.jl +++ b/test/test_linear_dependence.jl @@ -1,7 +1,6 @@ using EquivariantModels, StaticArrays, Test, Polynomials4ML, LinearAlgebra using ACEbase.Testing: print_tf -using Rotations, WignerD, BlockDiagonals -using EquivariantModels: Radial_basis, xx2AA, degord2spec +using EquivariantModels: Radial_basis, degord2spec using Polynomials4ML:lux using RepLieGroups From ba4027669b06090e4ea8fe35be62d77b6493df4c Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Thu, 5 Oct 2023 00:36:46 -0700 Subject: [PATCH 12/20] typo fix --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index abce0e8..99fc0b2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,6 +9,6 @@ using Test include("test_rSH_equivariance.jl") end @testset "Linear_Dependence" begin - include("test_linear_dependence") + include("test_linear_dependence.jl") end end From a115a63a9fc69bd12a4e58cdc29959ea8906ced5 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Thu, 5 Oct 2023 10:28:06 -0700 Subject: [PATCH 13/20] A simple energy test that checks the "completeness" of the basis for single species --- test/Yangshuais_energy_est.jl | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/test/Yangshuais_energy_est.jl b/test/Yangshuais_energy_est.jl index 87b619d..fc5a267 100644 --- a/test/Yangshuais_energy_est.jl +++ b/test/Yangshuais_energy_est.jl @@ -7,7 +7,7 @@ rng = Random.MersenneTwister() using ASE, JuLIP function gen_dat() eam = JuLIP.Potentials.EAM("test/w_eam4.fs") - at = rattle!(bulk(:W, cubic=true) * 2, 0.1) + at = rattle!(bulk(:W, cubic=true) * 2, 0.2) set_data!(at, "energy", energy(eam, at)) return at end @@ -16,11 +16,11 @@ train = [gen_dat() for _ = 1:100]; rcut = 5.5 maxL = 0 -totdeg = 6 -ord = 3 +totdeg = 10 +ord = 2 fcut(rcut::Float64,pin::Int=2,pout::Int=2) = r -> (r < rcut ? abs( (r/rcut)^pin - 1)^pout : 0) -ftrans(r0::Float64=.0,p::Int=2) = r -> ( (1+r0)/(1+r) )^p +ftrans(r0::Float64=2.0,p::Int=2) = r -> ( (1+r0)/(1+r) )^p radial = simple_radial_basis(legendre_basis(totdeg),fcut(rcut),ftrans()) Aspec, AAspec = degord2spec(radial; totaldegree = totdeg, @@ -166,22 +166,22 @@ obj_g! = (g, x) -> copyto!(g, ReverseDiff.gradient(p -> E_loss(train, calc, p), using LineSearches: BackTracking using LineSearches # solver = Optim.ConjugateGradient()#linesearch = BackTracking(order=2, maxstep=Inf)) -# solver = Optim.GradientDescent(linesearch = BackTracking(order=2, maxstep=Inf) ) -# solver = Optim.BFGS() -solver = Optim.LBFGS() #alphaguess = LineSearches.InitialHagerZhang(), +# solver = Optim.GradientDescent() +solver = Optim.BFGS() +# solver = Optim.LBFGS() #alphaguess = LineSearches.InitialHagerZhang(), # linesearch = BackTracking(order=2, maxstep=Inf) ) res = optimize(obj_f, obj_g!, p0, solver, - Optim.Options(f_tol = 1e-10, g_tol = 1e-6, show_trace = true)) + Optim.Options(g_tol = 1e-6, show_trace = true)) Eerrmin = Optim.minimum(res) RMSE = sqrt(Eerrmin / length(train)) pargmin = Optim.minimizer(res) p1 = pargmin -train = [gen_dat() for _ = 1:500]; +train = [gen_dat() for _ = 1:200]; res_new = optimize(obj_f, obj_g!, p1, solver, - Optim.Options(f_tol = 1e-10, g_tol = 1e-5, show_trace = true)) + Optim.Options(g_tol = 1e-6, show_trace = true)) Eerrmin_new = Optim.minimum(res_new) RMSE_new = sqrt(Eerrmin_new / length(train)) @@ -208,15 +208,17 @@ for te in test push!(Eace_te, estim) end +MIN = Eref_te |> minimum +MAX = Eref_te |> maximum using PyPlot figure() scatter(Eref, Eace, c="red", alpha=0.4) scatter(Eref_te, Eace_te, c="blue", alpha=0.4) -plot(-142.3:0.01:-141.5, -142.3:0.01:-141.5, lw=2, c="k", ls="--") +plot(-142.3:0.01:-137.5, -142.3:0.01:-137.5, lw=2, c="k", ls="--") PyPlot.legend(["Train", "Test"], fontsize=14, loc=2); xlabel("Reference energy") ylabel("ACE energy") axis("square") -xlim([-142.3, -141.5]) -ylim([-142.3, -141.5]) +xlim([-142.3, -137.5]) +ylim([-142.3, -137.5]) PyPlot.savefig("W_energy_fitting.png") \ No newline at end of file From b36359694aa093d6e69cbd22fd7634d0117fc2e9 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Thu, 5 Oct 2023 12:49:33 -0700 Subject: [PATCH 14/20] Resolve most of the issues in comments --- src/ConstLinearLayer.jl | 12 +++++------- src/builder.jl | 1 + 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/ConstLinearLayer.jl b/src/ConstLinearLayer.jl index 22714a1..4bba806 100644 --- a/src/ConstLinearLayer.jl +++ b/src/ConstLinearLayer.jl @@ -2,17 +2,15 @@ import ChainRulesCore: rrule using LuxCore using LuxCore: AbstractExplicitLayer -struct ConstLinearLayer{T} <: AbstractExplicitLayer # where {in_dim,out_dim,T} - W::AbstractMatrix{T} +struct ConstLinearLayer <: AbstractExplicitLayer # where {in_dim,out_dim,T} + W #::AbstractMatrix{T} position::Union{Vector{Int64}, UnitRange{Int64}} - in_dim::Integer - out_dim::Integer end -ConstLinearLayer(W::AbstractMatrix{T}) where T = ConstLinearLayer(W,1:size(W,2),size(W,2),size(W,1)) -ConstLinearLayer(W::AbstractMatrix{T}, pos::Union{Vector{Int64}, UnitRange{Int64}}) where T = ConstLinearLayer(W,pos,size(W,2),size(W,1)) +ConstLinearLayer(W) where T = ConstLinearLayer(W,1:size(W,2)) +# ConstLinearLayer(W, pos::Union{Vector{Int64}, UnitRange{Int64}}) = ConstLinearLayer(W,pos) -(l::ConstLinearLayer)(x::AbstractVector) = l.in_dim == length(x[l.position]) ? l.W * x[l.position] : error("x (or the position index) has a wrong length!") +(l::ConstLinearLayer)(x::AbstractVector) = l.W * x[l.position] (l::ConstLinearLayer)(x::AbstractMatrix) = begin Tmp = l(x[1,:]) diff --git a/src/builder.jl b/src/builder.jl index 0a35c50..1242e37 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -261,6 +261,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= C = _rpi_A2B_matrix(cgen, spec_nlm) end + #TODO:make use [ C[i], pos[i] ] to generate another sparse matrix so that... l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(C[i],pos[i]) for i = 1:L+1]... ) : ConstLinearLayer(C) # C - A2Bmap luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB) From ae937d125c281bb8fa8f96306102ae2b86483110 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Thu, 5 Oct 2023 12:54:21 -0700 Subject: [PATCH 15/20] Renaming W to op to avoid ambiguity of learnable weigh and constant multiplier --- src/ConstLinearLayer.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/ConstLinearLayer.jl b/src/ConstLinearLayer.jl index 4bba806..5dcbb0d 100644 --- a/src/ConstLinearLayer.jl +++ b/src/ConstLinearLayer.jl @@ -3,14 +3,14 @@ using LuxCore using LuxCore: AbstractExplicitLayer struct ConstLinearLayer <: AbstractExplicitLayer # where {in_dim,out_dim,T} - W #::AbstractMatrix{T} + op #::AbstractMatrix{T} position::Union{Vector{Int64}, UnitRange{Int64}} end -ConstLinearLayer(W) where T = ConstLinearLayer(W,1:size(W,2)) -# ConstLinearLayer(W, pos::Union{Vector{Int64}, UnitRange{Int64}}) = ConstLinearLayer(W,pos) +ConstLinearLayer(op) = ConstLinearLayer(op,1:size(op,2)) +# ConstLinearLayer(op, pos::Union{Vector{Int64}, UnitRange{Int64}}) = ConstLinearLayer(op,pos) -(l::ConstLinearLayer)(x::AbstractVector) = l.W * x[l.position] +(l::ConstLinearLayer)(x::AbstractVector) = l.op * x[l.position] (l::ConstLinearLayer)(x::AbstractMatrix) = begin Tmp = l(x[1,:]) @@ -23,7 +23,7 @@ ConstLinearLayer(W) where T = ConstLinearLayer(W,1:size(W,2)) function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractVector) val = l(x) function pb(A) - return NoTangent(), NoTangent(), l.W' * A[1], (W = A[1] * x',), NoTangent() + return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent() end return val, pb end @@ -33,7 +33,7 @@ end function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractArray,ps,st) val = l(x,ps,st) function pb(A) - return NoTangent(), NoTangent(), l.W' * A[1], (W = A[1] * x',), NoTangent() + return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent() end return val, pb end @@ -41,7 +41,7 @@ end # function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractMatrix, ps, st) # val = l(x, ps, st) # function pb(A) -# return NoTangent(), NoTangent(), l.W' * A[1], (W = A[1] * x',), NoTangent() +# return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent() # end # return val, pb # end \ No newline at end of file From 149447d416b2e40cfb5ac0be080116944adfe89b Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Thu, 5 Oct 2023 13:50:57 -0700 Subject: [PATCH 16/20] get rid of the "position" projectin --- src/builder.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/builder.jl b/src/builder.jl index 1242e37..2bbab0c 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -262,7 +262,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= end #TODO:make use [ C[i], pos[i] ] to generate another sparse matrix so that... - l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(C[i],pos[i]) for i = 1:L+1]... ) : ConstLinearLayer(C) + l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i])) for i = 1:L+1]... ) : ConstLinearLayer(C) # C - A2Bmap luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB) # luxchain = Chain(xx2AA = luxchain_tmp, BB = l_sym) @@ -272,6 +272,15 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= return luxchain, ps, st end +function new_sparse_matrix(C,pos) + col = maximum(pos) + C_new = sparse(zeros(typeof(C[1]),size(C,1),col)) + for i = 1:size(C,1) + C_new[i,pos] = C[i,:] + end + return C_new +end + # more constructors equivariant_model equivariant_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false) = equivariant_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax=L, islong = islong)[2], radial, L; categories, d, group, islong, rSH) From 42ad3a927f0ac952b37fa5aaba6238b1b65ba834 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Thu, 5 Oct 2023 14:27:35 -0700 Subject: [PATCH 17/20] clean up --- src/ConstLinearLayer.jl | 25 +++++++------------------ src/builder.jl | 3 +-- 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/src/ConstLinearLayer.jl b/src/ConstLinearLayer.jl index 5dcbb0d..d03ccd6 100644 --- a/src/ConstLinearLayer.jl +++ b/src/ConstLinearLayer.jl @@ -2,15 +2,11 @@ import ChainRulesCore: rrule using LuxCore using LuxCore: AbstractExplicitLayer -struct ConstLinearLayer <: AbstractExplicitLayer # where {in_dim,out_dim,T} - op #::AbstractMatrix{T} - position::Union{Vector{Int64}, UnitRange{Int64}} +struct ConstLinearLayer <: AbstractExplicitLayer + op end -ConstLinearLayer(op) = ConstLinearLayer(op,1:size(op,2)) -# ConstLinearLayer(op, pos::Union{Vector{Int64}, UnitRange{Int64}}) = ConstLinearLayer(op,pos) - -(l::ConstLinearLayer)(x::AbstractVector) = l.op * x[l.position] +(l::ConstLinearLayer)(x::AbstractVector) = l.op * x[1:size(l.op,2)] (l::ConstLinearLayer)(x::AbstractMatrix) = begin Tmp = l(x[1,:]) @@ -20,6 +16,9 @@ ConstLinearLayer(op) = ConstLinearLayer(op,1:size(op,2)) return Tmp' end + (l::ConstLinearLayer)(x::AbstractArray,ps,st) = (l(x), st) + + # NOTE: the following rrule is kept because there is a issue with SparseArray function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractVector) val = l(x) function pb(A) @@ -28,20 +27,10 @@ function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractVector) return val, pb end -(l::ConstLinearLayer)(x::AbstractArray,ps,st) = (l(x), st) - function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractArray,ps,st) val = l(x,ps,st) function pb(A) return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent() end return val, pb -end - -# function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractMatrix, ps, st) -# val = l(x, ps, st) -# function pb(A) -# return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent() -# end -# return val, pb -# end \ No newline at end of file +end \ No newline at end of file diff --git a/src/builder.jl b/src/builder.jl index 2bbab0c..c215411 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -261,7 +261,6 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= C = _rpi_A2B_matrix(cgen, spec_nlm) end - #TODO:make use [ C[i], pos[i] ] to generate another sparse matrix so that... l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i])) for i = 1:L+1]... ) : ConstLinearLayer(C) # C - A2Bmap luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB) @@ -278,7 +277,7 @@ function new_sparse_matrix(C,pos) for i = 1:size(C,1) C_new[i,pos] = C[i,:] end - return C_new + return sparse(C_new) end # more constructors equivariant_model From 8e63f32c3acb4426c76e5b0d07fd097b4ef9041d Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Thu, 5 Oct 2023 16:03:20 -0700 Subject: [PATCH 18/20] Minor revision --- src/ConstLinearLayer.jl | 8 ++++---- src/builder.jl | 11 +---------- src/utils.jl | 8 ++++++++ 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/ConstLinearLayer.jl b/src/ConstLinearLayer.jl index d03ccd6..d2a9f7f 100644 --- a/src/ConstLinearLayer.jl +++ b/src/ConstLinearLayer.jl @@ -2,13 +2,13 @@ import ChainRulesCore: rrule using LuxCore using LuxCore: AbstractExplicitLayer -struct ConstLinearLayer <: AbstractExplicitLayer - op +struct ConstLinearLayer{T} <: AbstractExplicitLayer + op::T end -(l::ConstLinearLayer)(x::AbstractVector) = l.op * x[1:size(l.op,2)] +(l::ConstLinearLayer{T})(x::AbstractVector) where T = l.op * x -(l::ConstLinearLayer)(x::AbstractMatrix) = begin +(l::ConstLinearLayer{T})(x::AbstractMatrix) where T = begin Tmp = l(x[1,:]) for i = 2:size(x,1) Tmp = [Tmp l(x[i,:])] diff --git a/src/builder.jl b/src/builder.jl index c215411..e615144 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -261,7 +261,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= C = _rpi_A2B_matrix(cgen, spec_nlm) end - l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i])) for i = 1:L+1]... ) : ConstLinearLayer(C) + l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C) # C - A2Bmap luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB) # luxchain = Chain(xx2AA = luxchain_tmp, BB = l_sym) @@ -271,15 +271,6 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= return luxchain, ps, st end -function new_sparse_matrix(C,pos) - col = maximum(pos) - C_new = sparse(zeros(typeof(C[1]),size(C,1),col)) - for i = 1:size(C,1) - C_new[i,pos] = C[i,:] - end - return sparse(C_new) -end - # more constructors equivariant_model equivariant_model(totdeg::Int64, ν::Int64, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false) = equivariant_model(degord2spec(radial; totaldegree = totdeg, order = ν, Lmax=L, islong = islong)[2], radial, L; categories, d, group, islong, rSH) diff --git a/src/utils.jl b/src/utils.jl index d319690..68cf47c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -253,3 +253,11 @@ function degord2spec(radial::Radial_basis; totaldegree, order, Lmax, catagories end get_i(i) = WrappedFunction(t -> t[i]) + +function new_sparse_matrix(C,pos,len) + C_new = sparse(zeros(typeof(C[1]),size(C,1),len)) + for i = 1:size(C,1) + C_new[i,pos] = C[i,:] + end + return sparse(C_new) +end \ No newline at end of file From bdffa55a74603a7c7f8a4f35e5a72e547b53eaa7 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Thu, 5 Oct 2023 23:09:53 -0700 Subject: [PATCH 19/20] faster linear transformation construction --- src/builder.jl | 3 ++- src/utils.jl | 11 ++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/builder.jl b/src/builder.jl index e615144..5cab21b 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -261,7 +261,8 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= C = _rpi_A2B_matrix(cgen, spec_nlm) end - l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C) + # l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C) + l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(C[i]*sparse_trans(pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C) # C - A2Bmap luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB) # luxchain = Chain(xx2AA = luxchain_tmp, BB = l_sym) diff --git a/src/utils.jl b/src/utils.jl index 68cf47c..79bfe95 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -254,10 +254,11 @@ end get_i(i) = WrappedFunction(t -> t[i]) -function new_sparse_matrix(C,pos,len) - C_new = sparse(zeros(typeof(C[1]),size(C,1),len)) - for i = 1:size(C,1) - C_new[i,pos] = C[i,:] +function sparse_trans(pos,len::Int64) + @assert maximum(pos) <= len + A = sparse(zeros(Int64, length(pos),len)) + for i = 1:length(pos) + A[i,pos[i]] = 1 end - return sparse(C_new) + return sparse(A) end \ No newline at end of file From 7223e6cba156ea6da5f8c999522188e4d377e410 Mon Sep 17 00:00:00 2001 From: zhanglw0521 Date: Fri, 13 Oct 2023 15:25:37 -0700 Subject: [PATCH 20/20] turning C * X[pos] to LO * X with LO a combined LinearOperator --- Project.toml | 1 + src/ConstLinearLayer.jl | 21 +++++++++++++++++++-- src/builder.jl | 4 +--- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 088c94e..985d165 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" diff --git a/src/ConstLinearLayer.jl b/src/ConstLinearLayer.jl index d2a9f7f..52ee758 100644 --- a/src/ConstLinearLayer.jl +++ b/src/ConstLinearLayer.jl @@ -1,5 +1,5 @@ import ChainRulesCore: rrule -using LuxCore +using LuxCore, LinearOperators using LuxCore: AbstractExplicitLayer struct ConstLinearLayer{T} <: AbstractExplicitLayer @@ -33,4 +33,21 @@ function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractArray,ps return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent() end return val, pb -end \ No newline at end of file +end + +function _linear_operator_L(L, C, pos, len) + if L == 0 + T = ComplexF64 + fL = let C=C, idx=pos#, T=T + (res, aa) -> mul!(res, C, aa[idx]);# try; mul!(res, C, aa[idx]); catch; mul!(zeros(T,size(C,1)), C, aa[idx]); end + end + else + T = SVector{2L+1,ComplexF64} + fL = let C=C, idx=pos#, T=T + (res, aa) -> begin + res[:] .= C * aa[idx] + end + end + end + return LinearOperator{T}(size(C,1), len, false, false, fL, nothing, nothing; S = Vector{T}) + end \ No newline at end of file diff --git a/src/builder.jl b/src/builder.jl index 5cab21b..c58d543 100644 --- a/src/builder.jl +++ b/src/builder.jl @@ -260,9 +260,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories= cgen = rSH ? Rot3DCoeffs_real(L) : Rot3DCoeffs(L) # TODO: this should be made group related C = _rpi_A2B_matrix(cgen, spec_nlm) end - - # l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(new_sparse_matrix(C[i],pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C) - l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(C[i]*sparse_trans(pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C) + l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(_linear_operator_L(i-1,C[i],pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C) # C - A2Bmap luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB) # luxchain = Chain(xx2AA = luxchain_tmp, BB = l_sym)