Skip to content

Commit

Permalink
Make the tests more rigorous
Browse files Browse the repository at this point in the history
NOTE: To make our current tests more rigorous, I hiked some functions in ACE.Wigner
  • Loading branch information
zhanglw0521 committed Sep 19, 2023
1 parent 07b1af6 commit 1d3f27d
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 28 deletions.
4 changes: 2 additions & 2 deletions src/builder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,8 @@ function cgmatrix(L1,L2)
for l = 0:L1+L2
if abs(ν)<=l
position = ν + l + 1
# cgm[i, l^2+position] = (-1)^q * cg(L1,p,L2,q,l,ν)
cgm[i, l^2+position] = cg(L1,p,L2,q,l,ν)
cgm[i, l^2+position] = (-1)^q * cg(L1,p,L2,q,l,ν)
# cgm[i, l^2+position] = cg(L1,p,L2,q,l,ν)
# cgm[i, l^2+position] = (-1)^q * sqrt( (2L1+1) * (2L2+1) ) / 2 / sqrt(π * (2l+1)) * cg(L1,0,L2,0,l,0) * cg(L1,p,L2,q,l,ν)
end
end
Expand Down
72 changes: 46 additions & 26 deletions test/test_equivariance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ using ACEbase.Testing: print_tf
using Rotations, WignerD, BlockDiagonals
using LinearAlgebra

include("wigner.jl")

@info("Testing the chain that generates a single B basis")
totdeg = 6
ν = 2
Expand All @@ -20,13 +22,16 @@ for L = 0:Lmax

@info("Tesing L = $L O(3) equivariance")
for _ = 1:30
local X, θ, Q, QX
local X, θ1, θ2, θ3, Q, QX
X = [ @SVector(rand(3)) for i in 1:10 ]
θ = rand() * 2pi
Q = RotXYZ(0, 0, θ)
θ1 = rand() * 2pi
θ2 = rand() * 2pi
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
# Q = rand_rot()
QX = [SVector{3}(x) for x in Ref(Q) .* X]
D = wignerD(L, 0, 0, θ)
D = wigner_D(L,Matrix(Q))'
# D = wignerD(L, θ, θ, θ)
if L == 0
print_tf(@test F(X) F(QX))
else
Expand Down Expand Up @@ -55,16 +60,19 @@ luxchain2, ps2, st2 = equivariant_model(EquivariantModels.degord2spec_nlm(totdeg
F2(X) = luxchain(X, ps2, st2)[1]

for ntest = 1:10
local X, θ, Q, QX
local X, θ1, θ2, θ3, Q, QX
X = [ @SVector(rand(3)) for i in 1:10 ]
θ = rand() * 2pi
Q = RotXYZ(0, 0, θ)
θ1 = rand() * 2pi
θ2 = rand() * 2pi
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
QX = [SVector{3}(x) for x in Ref(Q) .* X]

print_tf(@test F(X)[1] F(QX)[1])

for l = 2:L
D = wignerD(l-1, 0, 0, θ)
D = wigner_D(l-1,Matrix(Q))'
# D = wignerD(l-1, 0, 0, θ)
print_tf(@test norm.(Ref(D') .* F(X)[l] - F(QX)[l]) |> norm <1e-8)
end
end
Expand Down Expand Up @@ -145,12 +153,14 @@ F2(X) = luxchain(X, ps2, st2)[1]
@info("Tesing L = $L O(3) full equivariance")

for ntest = 1:20
local X, θ, Q, QX
local X, θ1, θ2, θ3, Q, QX
X = [ @SVector(rand(3)) for i in 1:10 ]
θ = rand() * 2pi
Q = RotXYZ(0, 0, θ)
θ1 = rand() * 2pi
θ2 = rand() * 2pi
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
QX = [SVector{3}(x) for x in Ref(Q) .* X]
D = BlockDiagonal([ wignerD(l, 0, 0, θ) for l = 0:L] )
D = BlockDiagonal([ wigner_D(l,Matrix(Q))' for l = 0:L] )

print_tf(@test Ref(D) .* F(QX) F(X))
end
Expand Down Expand Up @@ -178,12 +188,14 @@ F(X) = luxchain(X, ps, st)[1]
@info("Tesing L = $L O(3) full equivariance")

for ntest = 1:20
local X, θ, Q, QX
local X, θ1, θ2, θ3, Q, QX
X = [ @SVector(rand(3)) for i in 1:10 ]
θ = rand() * 2pi
Q = RotXYZ(0, 0, θ)
θ1 = rand() * 2pi
θ2 = rand() * 2pi
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
QX = [SVector{3}(x) for x in Ref(Q) .* X]
D = BlockDiagonal([ wignerD(l, 0, 0, θ) for l = 0:L] )
D = BlockDiagonal([ wigner_D(l,Matrix(Q))' for l = 0:L] )

print_tf(@test Ref(D) .* F(QX) F(X))
end
Expand All @@ -204,15 +216,19 @@ F(X) = luxchain(X, ps, st)[1]
@info("Equivariance test")
l1l2set = [(l1,l2) for l1 = 0:L for l2 = 0:L-l1]
for ntest = 1:10
local X, θ, Q, QX
local X, θ1, θ2, θ3, Q, QX
X = [ @SVector(rand(3)) for i in 1:10 ]
θ = rand() * 2pi
Q = RotXYZ(0, 0, θ)
θ1 = rand() * 2pi
θ2 = rand() * 2pi
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
QX = [SVector{3}(x) for x in Ref(Q) .* X]

for (i,(l1,l2)) in enumerate(l1l2set)
D1 = wignerD(l1, 0, 0, θ)
D2 = wignerD(l2, 0, 0, θ)
D1 = wigner_D(l1,Matrix(Q))'
D2 = wigner_D(l2,Matrix(Q))'
# D1 = wignerD(l1, 0, 0, θ)
# D2 = wignerD(l2, 0, 0, θ)
if F(X)[i] |> length 0
print_tf(@test norm(Ref(D1') .* F(X)[i] .* Ref(D2) - F(QX)[i]) < 1e-8)
end
Expand All @@ -228,16 +244,20 @@ luxchain, ps, st = EquivariantModels.equivariant_luxchain_constructor_new(totdeg
F(X) = luxchain(X, ps, st)[1]

for ntest = 1:10
local X, θ, Q, QX
local X, θ1, θ2, θ3, Q, QX
X = [ @SVector(rand(3)) for i in 1:10 ]
θ = rand() * 2pi
Q = RotXYZ(0, 0, θ)
θ1 = rand() * 2pi
θ2 = rand() * 2pi
θ3 = rand() * 2pi
Q = RotXYZ(θ1, θ2, θ3)
QX = [SVector{3}(x) for x in Ref(Q) .* X]

for i = 1:length(F(X))
l1,l2 = Int.(size(F(X)[i][1]).-1)./2
D1 = wignerD(l1, 0, 0, θ)
D2 = wignerD(l2, 0, 0, θ)
D1 = wigner_D(Int(l1),Matrix(Q))'
D2 = wigner_D(Int(l2),Matrix(Q))'
# D1 = wignerD(l1, 0, 0, θ)
# D2 = wignerD(l2, 0, 0, θ)
print_tf(@test Ref(D1') .* F(X)[i] .* Ref(D2) - F(QX)[i] |> norm < 1e-12)
end
end
Expand Down
76 changes: 76 additions & 0 deletions test/wigner.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
using StaticArrays

"""
Index of entries in D matrix (sign included)
"""
struct D_Index
sign::Int64
μ::Int64
m::Int64
end

"""
auxiliary matrix - indices for D matrix
"""
wigner_D_indices(L::Integer) = ( @assert L >= 0;
[ D_Index(1, i - 1 - L, j - 1 - L) for j = 1:2*L+1, i = 1:2*L+1] )

Base.adjoint(idx::D_Index) = D_Index( (-1)^(idx.μ+idx.m), - idx.μ, - idx.m)

"""
One entry of the Wigner-big-D matrix, `[D^l]_{mu, m}`
"""
wigner_D(μ,m,l,α,β,γ) = (exp(-im*α*m) * wigner_d(m,μ,l,β) * exp(-im*γ*μ))'



"""
One entry of the Wigner-small-d matrix,
Wigner small d, modified from
```
https://github.com/cortner/SlaterKoster.jl/blob/
8dceecb073709e6448a7a219ed9d3a010fa06724/src/code_generation.jl#L73
```
"""
function wigner_d(μ, m, l, β)
fc1 = factorial(l+m)
fc2 = factorial(l-m)
fc3 = factorial(l+μ)
fc4 = factorial(l-μ)
fcm1 = sqrt(fc1 * fc2 * fc3 * fc4)

cosb = cos/ 2.0)
sinb = sin/ 2.0)

p = m - μ
low = max(0,p)
high = min(l+m,l-μ)

temp = 0.0
for s = low:high
fc5 = factorial(s)
fc6 = factorial(l+m-s)
fc7 = factorial(l-μ-s)
fc8 = factorial(s-p)
fcm2 = fc5 * fc6 * fc7 * fc8
pow1 = 2 * l - 2 * s + p
pow2 = 2 * s - p
temp += (-1)^(s+p) * cosb^pow1 * sinb^pow2 / fcm2
end
temp *= fcm1

return temp
end

mat2ang(Q) = mod(atan(Q[2,3],Q[1,3]),2pi), acos(Q[3,3]), mod(atan(Q[3,2],-Q[3,1]),2pi);


function wigner_D(L::Integer, Q::AbstractMatrix)
D = wigner_D_indices(L);
α, β, γ = mat2ang(Q);
Mat_D = [ wigner_D(D[i,j].μ, D[i,j].m, L, α, β, γ)
for i = 1:2*L+1, j = 1:2*L+1 ]
# NB: type instability here, but performance is not important.
return SMatrix{2L+1, 2L+1, ComplexF64}(Mat_D)
end

0 comments on commit 1d3f27d

Please sign in to comment.