Skip to content

Commit

Permalink
Merge pull request #31 from ACEsuit/co/p4ml
Browse files Browse the repository at this point in the history
Update P4ML to 0.3.x
  • Loading branch information
cortner authored Jun 22, 2024
2 parents a7402d7 + 724885d commit 8d6c62d
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 44 deletions.
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EquivariantModels"
uuid = "73ee3e68-46fd-466f-9c56-451dc0291ebc"
authors = ["Christoph Ortner <christohortner@gmail.com> and contributors"]
version = "0.0.3"
authors = ["Christoph Ortner <christophortner@gmail.com> and contributors"]
version = "0.0.4"

[deps]
ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e"
Expand All @@ -13,7 +13,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
ObjectPools = "658cac36-ff0f-48ad-967c-110375d98c9d"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Polynomials4ML = "03c4bcba-a943-47e9-bfa1-b1661fc2974f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -27,6 +26,8 @@ WignerD = "87c4ff3e-34df-11e9-37a7-516cea4e0402"

[compat]
julia = "1"
DecoratedParticles = "0.0.5"
Polynomials4ML = "0.3.1"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
15 changes: 9 additions & 6 deletions src/ConstLinearLayer.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using LuxCore, LinearOperators
using LuxCore: AbstractExplicitLayer
using ObjectPools: unwrap, release!
using SparseArrays: AbstractSparseMatrixCSC, nonzeros, rowvals, nzrange
using LinearAlgebra: Adjoint

Expand All @@ -12,25 +11,29 @@ end

# === evaluation interface ===
_valtype(op::AbstractMatrix{<: Number}, x) = promote_type(eltype(op), eltype(x))
_valtype(op::AbstractMatrix{<:Number}, x::AbstractVector{<:Number}) = promote_type(eltype(op), eltype(x))
_valtype(op::AbstractMatrix{<:Number}, x::AbstractVector{<:StaticVector}) = promote_type(eltype(op), eltype(x))

_valtype(op::AbstractMatrix{<: AbstractVector}, x::AbstractArray{<: Number}) = SVector{length(op[1]), promote_type(eltype(op[1]), eltype(x[1][1]))}
_valtype(op::AbstractMatrix{<: AbstractVector}, x::AbstractArray{<: AbstractVector}) = promote_type(eltype(op[1]), eltype(x[1][1]))
_valtype(op::AbstractMatrix{<: AbstractVector}, x::AbstractVector{<: Number}) = SVector{length(op[1]), promote_type(eltype(op[1]), eltype(x))}
_valtype(op::AbstractMatrix{<: AbstractVector}, x::AbstractArray{<: AbstractVector}) = promote_type(eltype(op[1]), eltype(x[1]))
_valtype(op::AbstractMatrix{<:AbstractVector}, x::AbstractVector{<: StaticVector}) = promote_type(eltype(op[1]), eltype(x[1]))


(l::ConstLinearLayer)(x::AbstractArray, ps, st) = (l(x), st)

# sparse linear op interface
(l::ConstLinearLayer{<: AbstractSparseMatrixCSC})(x::AbstractVector) = begin
TT =_valtype(l.op, x)
out = zeros(TT, size(l.op, 1))
genmul!(out, l.op, unwrap(x), *)
release!(x)
genmul!(out, l.op, x, *)
return out
end

(l::ConstLinearLayer{<: AbstractSparseMatrixCSC})(x::AbstractMatrix) = begin
TT = _valtype(l.op, x)
out = zeros(TT, (size(l.op, 1), size(x, 2)))
genmul!(out, l.op, unwrap(x), *)
release!(x)
genmul!(out, l.op, x, *)
return out
end

Expand Down
8 changes: 4 additions & 4 deletions src/builder.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using LinearAlgebra
using SparseArrays: SparseMatrixCSC, sparse
using RepLieGroups.O3: Rot3DCoeffs, Rot3DCoeffs_real, Rot3DCoeffs_long, re_basis, SYYVector, mm_filter, coco_dot
using Polynomials4ML: legendre_basis, RYlmBasis, natural_indices, degree
using Polynomials4ML: legendre_basis, real_sphericalharmonics, complex_sphericalharmonics, natural_indices, degree
using Polynomials4ML.Utils: gensparse
using Lux: WrappedFunction
using Lux
Expand Down Expand Up @@ -171,7 +171,7 @@ function xx2AA(spec_nlm, radial::Radial_basis; categories=[], _get_cat = _get_ca
@assert issubset(nset(spec1p), radial.Radialspec) || issubset(nlset(spec1p), radial.Radialspec)

dict_spec1p = Dict([spec1p[i] => i for i = 1:length(spec1p)])
Ylm = rSH ? RYlmBasis(lmax) : CYlmBasis(lmax)
Ylm = rSH ? real_sphericalharmonics(lmax) : complex_sphericalharmonics(lmax)
# Rn = radial_basis(nmax)

if !isempty(categories)
Expand Down Expand Up @@ -342,7 +342,7 @@ function equivariant_luxchain_constructor(totdeg, ν, L; wL = 1, Rn = legendre_b
filter = RPE_filter_long(L)
cgen = Rot3DCoeffs_long(L)

Ylm = CYlmBasis(totdeg)
Ylm = complex_sphericalharmonics(totdeg)

spec1p = make_nlms_spec(simple_radial_basis(Rn), Ylm; totaldegree = totdeg, admissible = (br, by) -> br.n + wL * by.l <= totdeg)
spec1p = sort(spec1p, by = (x -> x.n + x.l * wL))
Expand Down Expand Up @@ -400,7 +400,7 @@ end
# What can be adjusted in its input are: (1) total polynomial degree; (2) correlation order; (3) largest L
# (4) weight of the order of spherical harmonics; (5) specified radial basis
function equivariant_luxchain_constructor_new(totdeg, ν, L; wL = 1, Rn = legendre_basis(totdeg))
Ylm = CYlmBasis(totdeg)
Ylm = complex_sphericalharmonics(totdeg)

spec1p = make_nlms_spec(simple_radial_basis(Rn), Ylm; totaldegree = totdeg, admissible = (br, by) -> br.n + wL * by.l <= totdeg)
spec1p = sort(spec1p, by = (x -> x.n + x.l * wL))
Expand Down
18 changes: 7 additions & 11 deletions src/categorical.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
using Polynomials4ML
using Polynomials4ML: AbstractPoly4MLBasis, SphericalCoords
using StaticArrays: SVector, StaticArray

import Polynomials4ML: _valtype, _out_size, _outsym, evaluate, evaluate!
import Polynomials4ML: _valtype, _out_size, _outsym, evaluate, evaluate!, AbstractP4MLBasis
import ChainRulesCore: rrule, NoTangent

export CategoricalBasis
Expand Down Expand Up @@ -69,20 +68,20 @@ variable ``u`` may take. Suppose, e.g., we allow the values `[:a, :b, :c]`,
then
```julia
P = CategoricalBasis([:a, :b, :c]; varsym = :u, idxsym = :q)
evaluate(P, State(u = :a)) # Bool[1, 0, 0]
evaluate(P, State(u = :b)) # Bool[0, 1, 0]
evaluate(P, State(u = :c)) # Bool[0, 0, 1]
evaluate(P, PState(u = :a)) # Bool[1, 0, 0]
evaluate(P, PState(u = :b)) # Bool[0, 1, 0]
evaluate(P, PState(u = :c)) # Bool[0, 0, 1]
```
If we evaluate it with an unknown state we get an error:
```julia
evaluate(P, State(u = :x))
evaluate(P, PState(u = :x))
# Error : val = x not found in this list
```
Warning : the list of categories is internally stored as an SVector
which means that lookup scales linearly with the number of categories
"""
struct CategoricalBasis{LEN, T} <: AbstractPoly4MLBasis
struct CategoricalBasis{LEN, T} <: AbstractP4MLBasis
categories::SList{LEN, T}
meta::Dict{String, Any}
end
Expand All @@ -93,10 +92,7 @@ CategoricalBasis(categories::AbstractArray, meta = Dict{String, Any}() ) =
CategoricalBasis(SList(categories), meta)


#Polynomials4ML._outsym(x::Char) = :char
_outsym(x::T) where T = :T

const NSS = Union{Number, SphericalCoords, StaticArray}
const NSS = Union{Number, StaticArray}

_out_size(basis::CategoricalBasis{LEN, T}, x::T) where {LEN, T} = (LEN,)
_out_size(basis::CategoricalBasis{LEN, T}, x::Vector{T}) where {LEN, T} = (length(x), LEN)
Expand Down
4 changes: 2 additions & 2 deletions src/radial.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Polynomials4ML: natural_indices, ScalarPoly4MLBasis, lux
using Polynomials4ML: natural_indices, lux
using LuxCore: AbstractExplicitContainerLayer, AbstractExplicitLayer
export Radial_basis

Expand All @@ -23,7 +23,7 @@ Radial_basis(Rnl::AbstractExplicitLayer) =
end

# it is in its current form just for the purpose of testing - a more specific example can be found in forces.jl
function simple_radial_basis(basis::ScalarPoly4MLBasis,f_cut::Function=r->1,f_trans::Function=r->r; spec = nothing, isState = false)
function simple_radial_basis(basis,f_cut::Function=r->1,f_trans::Function=r->r; spec = nothing, isState = false)
if isnothing(spec)
try
spec = natural_indices(basis)
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ Return a list of AA specifications and A specifications
"""
function degord2spec(radial::Radial_basis; totaldegree, order, Lmax, catagories = [], filtered_extension = simple_extension, wL = 1, islong = true, rSH = false)
# Rn = radial.radial_basis(totaldegree)
Ylm = CYlmBasis(totaldegree)
Ylm = complex_sphericalharmonics(totaldegree)

spec1p = make_nlms_spec(radial, Ylm; totaldegree = totaldegree, admissible = (br, by) -> br.n + wL * by.l <= totaldegree)
spec1p = sort(spec1p, by = (x -> x.n + x.l * wL))
Expand Down
2 changes: 1 addition & 1 deletion test/test_ConstantLinearLayer.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using EquivariantModels, Lux, StaticArrays, Random, LinearAlgebra, Polynomials4ML
using Polynomials4ML: LinearLayer, RYlmBasis, lux
using Polynomials4ML: LinearLayer, lux
using EquivariantModels: degord2spec, specnlm2spec1p, xx2AA, simple_radial_basis, ConstLinearLayer
using LuxCore
using SparseArrays
Expand Down
4 changes: 2 additions & 2 deletions test/test_equiv_with_cate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ Species = [ (species[1], species[i]) for i = 1:10 ]
for ntest = 1:10
local X, θ1, θ2, θ3, Q, QX
X = [ @SVector(rand(3)) for i in 1:10 ]
XX = [State(rr = X[i], Zi = Species[i][1], Zj = Species[i][2]) for i = 1:length(X)]
XX = [PState(rr = X[i], Zi = Species[i][1], Zj = Species[i][2]) for i = 1:length(X)]
θ1 = rand() * 2pi
θ2 = rand() * 2pi
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
QXX = [State(rr = Q * X[i], Zi = Species[i][1], Zj = Species[i][2]) for i = 1:length(X)]
QXX = [PState(rr = Q * X[i], Zi = Species[i][1], Zj = Species[i][2]) for i = 1:length(X)]
# QXX = [QX, Species]

print_tf(@test F(XX)[1] F(QXX)[1])
Expand Down
28 changes: 14 additions & 14 deletions test/test_equiv_with_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ for L = 0:Lmax
Q = RotXYZ(θ1, θ2, θ3)
# Q = rand_rot()
# QX = [SVector{3}(x) for x in Ref(Q) .* X]
QX = [ State(rr = Q * X[i]) for i in 1:length(X) ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
QX = [ PState(rr = Q * X[i]) for i in 1:length(X) ]
X = [ PState(rr = X[i]) for i in 1:length(X) ]
D = wigner_D(L,Matrix(Q))'
# D = wignerD(L, θ, θ, θ)
if L == 0
Expand All @@ -48,7 +48,7 @@ for L = 0:Lmax
for _ = 1:30
local X
X = [ @SVector(rand(3)) for i in 1:10 ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
X = [ PState(rr = X[i]) for i in 1:length(X) ]
print_tf(@test F(X) F2(X))
end
println()
Expand All @@ -74,8 +74,8 @@ for ntest = 1:10
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
# QX = [SVector{3}(x) for x in Ref(Q) .* X]
QX = [ State(rr = Q * X[i]) for i in 1:length(X) ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
QX = [ PState(rr = Q * X[i]) for i in 1:length(X) ]
X = [ PState(rr = X[i]) for i in 1:length(X) ]

print_tf(@test F(X)[1] F(QX)[1])

Expand Down Expand Up @@ -104,7 +104,7 @@ for l = 0:Lmax

for ntest = 1:20
X = [ @SVector(rand(3)) for i in 1:10 ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
X = [ PState(rr = X[i]) for i in 1:length(X) ]
print_tf(@test F(X)[l+1] == FF(X))
end
println()
Expand All @@ -114,7 +114,7 @@ end
for _ = 1:10
local X
X = [ @SVector(rand(3)) for i in 1:10 ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
X = [ PState(rr = X[i]) for i in 1:length(X) ]
print_tf(@test length(F(X)) == length(F2(X)) && all([F(X)[i] F2(X)[i] for i = 1:length(F(X))]))
end
println()
Expand All @@ -140,8 +140,8 @@ for L = 0:Lmax
Q = RotXYZ(0, 0, θ)
# Q = rand_rot()
# QX = [SVector{3}(x) for x in Ref(Q) .* X]
QX = [ State(rr = Q * X[i]) for i in 1:length(X) ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
QX = [ PState(rr = Q * X[i]) for i in 1:length(X) ]
X = [ PState(rr = X[i]) for i in 1:length(X) ]
D = wignerD(L, 0, 0, θ)
if length(F(X)) == 0
continue
Expand Down Expand Up @@ -177,8 +177,8 @@ for ntest = 1:20
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
# QX = [SVector{3}(x) for x in Ref(Q) .* X]
QX = [ State(rr = Q * X[i]) for i in 1:length(X) ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
QX = [ PState(rr = Q * X[i]) for i in 1:length(X) ]
X = [ PState(rr = X[i]) for i in 1:length(X) ]
D = BlockDiagonal([ wigner_D(l,Matrix(Q))' for l = 0:L] )

print_tf(@test Ref(D) .* F(QX) F(X))
Expand All @@ -189,7 +189,7 @@ println()
for _ = 1:10
local X
X = [ @SVector(rand(3)) for i in 1:10 ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
X = [ PState(rr = X[i]) for i in 1:length(X) ]
print_tf(@test length(F(X)) == length(F2(X)) && all([F(X)[i] F2(X)[i] for i = 1:length(F(X))]))
end
println()
Expand All @@ -215,8 +215,8 @@ for ntest = 1:20
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
# QX = [SVector{3}(x) for x in Ref(Q) .* X]
QX = [ State(rr = Q * X[i]) for i in 1:length(X) ]
X = [ State(rr = X[i]) for i in 1:length(X) ]
QX = [ PState(rr = Q * X[i]) for i in 1:length(X) ]
X = [ PState(rr = X[i]) for i in 1:length(X) ]
D = BlockDiagonal([ wigner_D(l,Matrix(Q))' for l = 0:L] )

print_tf(@test Ref(D) .* F(QX) F(X))
Expand Down

0 comments on commit 8d6c62d

Please sign in to comment.