Skip to content

Commit

Permalink
Merge pull request #18 from ACEsuit/Const_LinearLayer
Browse files Browse the repository at this point in the history
Constant Linear Layer
  • Loading branch information
cortner authored Oct 20, 2023
2 parents 7f032f4 + 7223e6c commit a968d5f
Show file tree
Hide file tree
Showing 8 changed files with 364 additions and 11 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ version = "0.0.2"
ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Expand Down
53 changes: 53 additions & 0 deletions src/ConstLinearLayer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import ChainRulesCore: rrule
using LuxCore, LinearOperators
using LuxCore: AbstractExplicitLayer

struct ConstLinearLayer{T} <: AbstractExplicitLayer
op::T
end

(l::ConstLinearLayer{T})(x::AbstractVector) where T = l.op * x

(l::ConstLinearLayer{T})(x::AbstractMatrix) where T = begin
Tmp = l(x[1,:])
for i = 2:size(x,1)
Tmp = [Tmp l(x[i,:])]
end
return Tmp'
end

(l::ConstLinearLayer)(x::AbstractArray,ps,st) = (l(x), st)

# NOTE: the following rrule is kept because there is a issue with SparseArray
function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractVector)
val = l(x)
function pb(A)
return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent()
end
return val, pb
end

function rrule(::typeof(LuxCore.apply), l::ConstLinearLayer, x::AbstractArray,ps,st)
val = l(x,ps,st)
function pb(A)
return NoTangent(), NoTangent(), l.op' * A[1], (op = A[1] * x',), NoTangent()
end
return val, pb
end

function _linear_operator_L(L, C, pos, len)
if L == 0
T = ComplexF64
fL = let C=C, idx=pos#, T=T
(res, aa) -> mul!(res, C, aa[idx]);# try; mul!(res, C, aa[idx]); catch; mul!(zeros(T,size(C,1)), C, aa[idx]); end
end
else
T = SVector{2L+1,ComplexF64}
fL = let C=C, idx=pos#, T=T
(res, aa) -> begin
res[:] .= C * aa[idx]
end
end
end
return LinearOperator{T}(size(C,1), len, false, false, fL, nothing, nothing; S = Vector{T})
end
44 changes: 37 additions & 7 deletions src/builder.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
using LinearAlgebra
using SparseArrays: SparseMatrixCSC, sparse
using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real, Rot3DCoeffs_long, re_basis, SYYVector, mm_filter
using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real, Rot3DCoeffs_long, re_basis, SYYVector, mm_filter, coco_dot
using Polynomials4ML: legendre_basis, RYlmBasis, natural_indices, degree
using Polynomials4ML.Utils: gensparse
using Lux: WrappedFunction
using Lux
using Random
using Polynomials4ML
using StaticArrays
using Combinatorics

export equivariant_model, equivariant_SYY_model, equivariant_luxchain_constructor, equivariant_luxchain_constructor_new

Expand All @@ -23,6 +24,33 @@ _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T},Rot3DCoeffs_real{L,T},Rot3DCoeffs_l
spec::Vector{Vector{NamedTuple}})
Return a sparse matrix for symmetrisation of AA basis of spec with equivariance specified by cgen
"""
function rpe_basis(A::Union{Rot3DCoeffs,Rot3DCoeffs_long,Rot3DCoeffs_real}, nn::SVector{N, TN}, ll::SVector{N, Int}) where {N, TN}
Ure, Mre = re_basis(A, ll)
G = _gramian(nn, ll, Ure, Mre)
S = svd(G)
rk = rank(Diagonal(S.S); rtol = 1e-7)
Urpe = S.U[:, 1:rk]'
return Diagonal(sqrt.(S.S[1:rk])) * Urpe * Ure, Mre
end


function _gramian(nn, ll, Ure, Mre)
N = length(nn)
nre = size(Ure, 1)
G = zeros(Complex{Float64}, nre, nre)
for σ in permutations(1:N)
if (nn[σ] != nn) || (ll[σ] != ll); continue; end
for (iU1, mm1) in enumerate(Mre), (iU2, mm2) in enumerate(Mre)
if mm1[σ] == mm2
for i1 = 1:nre, i2 = 1:nre
G[i1, i2] += coco_dot(Ure[i1, iU1], Ure[i2, iU2])
end
end
end
end
return G
end

function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Rot3DCoeffs_long{L,T}}, spec) where {L,T}
# allocate triplet format
Irow, Jcol = Int[], Int[]
Expand Down Expand Up @@ -55,7 +83,7 @@ function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Ro
if (nn,ll,ss) in nnllset; continue; end

# get the Mll indices and coeffs
U, Mll = re_basis(cgen, ll)
U, Mll = rpe_basis(cgen, nn, ll)
# conver the Mlls into basis functions (NamedTuples)

rpibs = [_nlms2b(nn, ll, mm, ss) for mm in Mll]
Expand Down Expand Up @@ -83,7 +111,8 @@ function _rpi_A2B_matrix(cgen::Union{Rot3DCoeffs{L,T}, Rot3DCoeffs_real{L,T}, Ro
if (nn,ll) in nnllset; continue; end

# get the Mll indices and coeffs
U, Mll = re_basis(cgen, ll)
# U, Mll = re_basis(cgen, ll)
U, Mll = rpe_basis(cgen, nn, ll)
# conver the Mlls into basis functions (NamedTuples)

rpibs = [_nlms2b(nn, ll, mm) for mm in Mll]
Expand Down Expand Up @@ -196,6 +225,9 @@ L : Largest equivariance level
categories : A list of categories
radial_basis : specified radial basis, default using P4ML.legendre_basis
"""

include("ConstLinearLayer.jl")

function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=[], d=3, group="O3", islong=true, rSH = false)
if rSH && L > 0
error("rSH is only implemented (for now) for L = 0")
Expand Down Expand Up @@ -228,9 +260,7 @@ function equivariant_model(spec_nlm, radial::Radial_basis, L::Int64; categories=
cgen = rSH ? Rot3DCoeffs_real(L) : Rot3DCoeffs(L) # TODO: this should be made group related
C = _rpi_A2B_matrix(cgen, spec_nlm)
end

l_sym = islong ? Lux.Parallel(nothing, [WrappedFunction(x -> C[i] * x[pos[i]]) for i = 1:L+1]... ) : WrappedFunction(x -> C * x)
# TODO: make it a Const_LinearLayer instead
l_sym = islong ? Lux.Parallel(nothing, [ConstLinearLayer(_linear_operator_L(i-1,C[i],pos[i],length(spec_nlm))) for i = 1:L+1]... ) : ConstLinearLayer(C)
# C - A2Bmap
luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB)
# luxchain = Chain(xx2AA = luxchain_tmp, BB = l_sym)
Expand Down Expand Up @@ -271,7 +301,7 @@ function equivariant_SYY_model(spec_nlm, radial::Radial_basis, L::Int64; categor

cgen = Rot3DCoeffs_long(L) # TODO: this should be made group related
C = _rpi_A2B_matrix(cgen, spec_nlm)
l_sym = WrappedFunction(x -> C * x)
l_sym = ConstLinearLayer(C)

# C - A2Bmap
luxchain = append_layer(luxchain_tmp, l_sym; l_name = :BB)
Expand Down
6 changes: 2 additions & 4 deletions src/radial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,14 @@ Radial_basis(Rnl::AbstractExplicitLayer) =
end

# it is in its current form just for the purpose of testing - a more specific example can be found in forces.jl
function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_trans::Function=r->1; spec = nothing)
function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_trans::Function=r->r; spec = nothing)
if isnothing(spec)
try
spec = natural_indices(basis)
catch
error("The specification of this Radial_basis should be given explicitly!")
end
end

f(r) = f_trans(r) * f_cut(r)

return Radial_basis(Chain(getnorm = WrappedFunction(x -> norm.(x)), trans = WrappedFunction(x -> f.(x)), poly = lux(basis), ), spec)
return Radial_basis(Chain(trans = WrappedFunction(x -> f_trans.(norm.(x))), evaluation = Lux.BranchLayer(poly = lux(basis), cutoff = WrappedFunction(x -> f_cut.(x))), env = WrappedFunction(x -> x[1].*x[2]), ), spec)
end
9 changes: 9 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,12 @@ function degord2spec(radial::Radial_basis; totaldegree, order, Lmax, catagories
end

get_i(i) = WrappedFunction(t -> t[i])

function sparse_trans(pos,len::Int64)
@assert maximum(pos) <= len
A = sparse(zeros(Int64, length(pos),len))
for i = 1:length(pos)
A[i,pos[i]] = 1
end
return sparse(A)
end
Loading

0 comments on commit a968d5f

Please sign in to comment.