Skip to content

Commit

Permalink
improve the fitting and predicting parallelization
Browse files Browse the repository at this point in the history
  • Loading branch information
ChenQianA committed May 28, 2024
1 parent ec5b7ce commit ea2c737
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 24 deletions.
5 changes: 3 additions & 2 deletions src/fitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,19 +172,20 @@ function _assemble_ls(basis::SymmetricBasis, data::T, enable_mean::Bool=false) w
np = length(procs(Avalr))
nstates = length(data.states)
nstates_pp = ceil(Int, nstates/np)
np = ceil(Int, nstates/nstates_pp)
idx_begins = [nstates_pp*(idx-1)+1 for idx in 1:np]
idx_ends = [nstates_pp*(idx) for idx in 1:(np-1)]
push!(idx_ends, nstates)
@sync begin
for (i, id) in enumerate(procs(Avalr))
for (i, id) in enumerate(procs(Avalr)[begin:np])
@async begin
@spawnat id begin
cfg = ACEConfig.(data.states[idx_begins[i]:idx_ends[i]])
Aval_ele = evaluate.(Ref(basis), cfg)
Avalr_ele = _evaluate_real.(Aval_ele)
Avalr_ele = permutedims(reduce(hcat, Avalr_ele), (2, 1))
@cast M[i,j,k,l] := Avalr_ele[i,j][k,l]
Avalr[idx_begins[i]: idx_ends[i], :, :, :] = M
Avalr[idx_begins[i]: idx_ends[i], :, :, :] .= M
end
end
end
Expand Down
207 changes: 185 additions & 22 deletions src/predicting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ module Predicting

using ACE, ACEbase, ACEhamiltonians, LinearAlgebra
using JuLIP: Atoms, neighbourlist
using ACE: ACEConfig, AbstractState, evaluate
using ACE: ACEConfig, AbstractState, SymmetricBasis, evaluate

using ACEhamiltonians.States: _get_states
using ACEhamiltonians.Fitting: _evaluate_real

using ACEhamiltonians: DUAL_BASIS_MODEL

using SharedArrays, Distributed, TensorCast

export predict, predict!, cell_translations


Expand Down Expand Up @@ -149,6 +151,139 @@ function predict!(values::AbstractMatrix, submodel::T, state::Vector{S}) where {
end


#########for parallelization

function get_discriptors(basis::SymmetricBasis, states::Vector{<:Vector{<:AbstractState}})
# This will be rewritten once the other code has been refactored.

# Should `A` not be constructed using `acquire_B!`?

n₁, n₂, type = ACE.valtype(basis).parameters[3:5]
# Currently the code desires "A" to be an X×Y matrix of Nᵢ×Nⱼ matrices, where X is
# the number of sub-block samples, Y is equal to `size(bos.basis.A2Bmap)[1]`, and
# Nᵢ×Nⱼ is the sub-block shape; i.e. 3×3 for pp interactions. This may be refactored
# at a later data if this layout is not found to be strictly necessary.
n₃ = length(states)

Avalr = SharedArray{real(type), 4}(n₃, length(basis), n₁, n₂)
np = length(procs(Avalr))
nstates = length(states)
nstates_pp = ceil(Int, nstates/np)
np = ceil(Int, nstates/nstates_pp)
idx_begins = [nstates_pp*(idx-1)+1 for idx in 1:np]
idx_ends = [nstates_pp*(idx) for idx in 1:(np-1)]
push!(idx_ends, nstates)
@sync begin
for (i, id) in enumerate(procs(Avalr)[begin:np])
@async begin
@spawnat id begin
cfg = ACEConfig.(states[idx_begins[i]:idx_ends[i]])
Aval_ele = evaluate.(Ref(basis), cfg)
Avalr_ele = _evaluate_real.(Aval_ele)
Avalr_ele = permutedims(reduce(hcat, Avalr_ele), (2, 1))
@cast M[i,j,k,l] := Avalr_ele[i,j][k,l]
Avalr[idx_begins[i]: idx_ends[i], :, :, :] .= M
end
end
end
end
@cast A[i,j][k,l] := Avalr[i,j,k,l]

return A

end


function infer(coefficients::Vector{Float64}, mean::Matrix{Float64}, B::SubArray{<:Any})
return coefficients' * B + mean
end

function infer(coefficients::Vector{Float64}, mean::Matrix{Float64}, B::Matrix{<:Any})
@cast B_broadcast[i][j] := B[i,j]
result = infer.(Ref(coefficients), Ref(mean), B_broadcast)
return result
end

function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}})
# If the model has been fitted then use it to predict the results; otherwise just
# assume the results are zero.
if is_fitted(submodel)
# Construct a descriptor representing the supplied state and evaluate the
# basis on it to predict the associated sub-block.
# A = evaluate(submodel.basis, ACEConfig(state))
# B = _evaluate_real(A)
B_batch = get_discriptors(submodel.basis, states)
# coeffs_expanded = repeat(submodel.coefficients', size(B_batch, 1), 1)
# means_expanded = fill(submodel.mean, size(B_batch, 1))
# values .= dropdims(sum(coeffs_expanded .* B_batch, dims=2), dims=2) + means_expanded
values .= cat(infer(submodel.coefficients, submodel.mean, B_batch)..., dims=3)
# values .= cat([(submodel.coefficients' * B_batch[i,:]) + submodel.mean for i in 1: size(B_batch,1)]..., dims=3)
# values = (submodel.coefficients' * B) + submodel.mean

@static if DUAL_BASIS_MODEL
if typeof(submodel) <: AnisoSubModel
# A = evaluate(submodel.basis_i, ACEConfig(reflect.(state)))
# B = _evaluate_real(A)
B_batch = get_discriptors(submodel.basis_i, [reflect.(state) for state in states])
# values .= (values + ((submodel.coefficients_i' * B) + submodel.mean_i)') / 2.0
# values .= (values + cat([((submodel.coefficients_i' * B_batch[i,:]) + submodel.mean_i)' for
# i in 1: size(B_batch,1)]..., dims=3))/2.0
values .= (values + permutedims(cat(infer(submodel.coefficients_i, submodel.mean_i, B_batch)..., dims=3), (2,1,3))) / 2.0

elseif !ison(submodel) && (submodel.id[1] == submodel.id[2]) && (submodel.id[3] == submodel.id[4])
# If the dual basis model is being used then it is assumed that the symmetry
# issue has not been resolved thus an additional symmetrisation operation is
# required.
# A = evaluate(submodel.basis, ACEConfig(reflect.(state)))
# B = _evaluate_real(A)
B_batch = get_discriptors(submodel.basis, [reflect.(state) for state in states])
# values .= (values + ((submodel.coefficients' * B) + submodel.mean)') / 2.0
# values .= (values + cat([((submodel.coefficients' * B_batch[i,:]) + submodel.mean)' for
# i in 1: size(B_batch,1)]..., dims=3))/2.0
values .= (values + permutedims(cat(infer(submodel.coefficients, submodel.mean, B_batch)..., dims=3), (2,1,3))) / 2.0
end
end

else
fill!(values, 0.0)
end
end



# function predict_state(submodel::T, state::Vector{S}) where {T<:AHSubModel, S<:AbstractState}
# # If the model has been fitted then use it to predict the results; otherwise just
# # assume the results are zero.
# if is_fitted(submodel)
# # Construct a descriptor representing the supplied state and evaluate the
# # basis on it to predict the associated sub-block.
# A = evaluate(submodel.basis, ACEConfig(state))
# B = _evaluate_real(A)
# values = (submodel.coefficients' * B) + submodel.mean

# @static if DUAL_BASIS_MODEL
# if T<: AnisoSubModel
# A = evaluate(submodel.basis_i, ACEConfig(reflect.(state)))
# B = _evaluate_real(A)
# values = (values + ((submodel.coefficients_i' * B) + submodel.mean_i)') / 2.0
# elseif !ison(submodel) && (submodel.id[1] == submodel.id[2]) && (submodel.id[3] == submodel.id[4])
# # If the dual basis model is being used then it is assumed that the symmetry
# # issue has not been resolved thus an additional symmetrisation operation is
# # required.
# A = evaluate(submodel.basis, ACEConfig(reflect.(state)))
# B = _evaluate_real(A)
# values = (values + ((submodel.coefficients' * B) + submodel.mean)') / 2.0
# end
# end

# else
# values = zeros(size(submodel))
# end
# end




# Construct and fill a matrix with the results of a single state

"""
Expand Down Expand Up @@ -186,40 +321,68 @@ specified states. This is a the batch operable variant of the primary `predict!`
# end


using Distributed
function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}})
@sync begin
for i=1:length(states)
@async begin
@spawn @views predict!(values[:, :, i], submodel, states[i])
end
end
end
end
# using Distributed
# function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}})
# @sync begin
# for i=1:length(states)
# @async begin
# @spawn @views predict!(values[:, :, i], submodel, states[i])
# end
# end
# end
# end


"""
"""
# function predict(submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}})
# # Construct and fill a matrix with the results from multiple states
# n, m, type = ACE.valtype(submodel.basis).parameters[3:5]
# values = Array{real(type), 3}(undef, n, m, length(states))
# predict!(values, submodel, states)
# return values

# function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}})
# np = length(procs(values))
# nstates = length(states)
# nstates_pp = ceil(Int, nstates/np)
# np = ceil(Int, nstates/nstates_pp)
# idx_begins = [nstates_pp*(idx-1)+1 for idx in 1:np]
# idx_ends = [nstates_pp*(idx) for idx in 1:(np-1)]
# push!(idx_ends, nstates)
# @sync begin
# for (i, id) in enumerate(procs(values)[begin:np])
# @async begin
# @spawnat id begin
# values[:, :, idx_begins[i]: idx_ends[i]] = cat(predict_state.(Ref(submodel), states[idx_begins[i]:idx_ends[i]])..., dims=3)
# end
# end
# end
# end
# end


# function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}})
# for i=1:length(states)
# @views predict!(values[:, :, i], submodel, states[i])
# end
# end


using SharedArrays

"""
"""
function predict(submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}})
# Construct and fill a matrix with the results from multiple states
n, m, type = ACE.valtype(submodel.basis).parameters[3:5]
# values = Array{real(type), 3}(undef, n, m, length(states))
values = SharedArray{real(type), 3}(n, m, length(states))
values = Array{real(type), 3}(undef, n, m, length(states))
predict!(values, submodel, states)
return values
end


# function predict(submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}})
# # Construct and fill a matrix with the results from multiple states
# n, m, type = ACE.valtype(submodel.basis).parameters[3:5]
# # values = Array{real(type), 3}(undef, n, m, length(states))
# values = SharedArray{real(type), 3}(n, m, length(states))
# predict!(values, submodel, states)
# return values
# end


# Special version of the batch operable `predict!` method that is used when scattering data
# into a Vector of AbstractMatrix types rather than into a three dimensional tensor. This
# is implemented to facilitate the scattering of data into collection of sub-view arrays.
Expand Down

0 comments on commit ea2c737

Please sign in to comment.