Skip to content

Commit

Permalink
predict parallelization
Browse files Browse the repository at this point in the history
  • Loading branch information
ChenQianA committed May 20, 2024
1 parent 448f913 commit 69892d4
Showing 1 changed file with 135 additions and 5 deletions.
140 changes: 135 additions & 5 deletions src/predicting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,19 +171,50 @@ Predict the values for a collection of sub-blocks by evaluating the provided bas
specified states. This is a the batch operable variant of the primary `predict!` method.
"""
# function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}})
# for i=1:length(states)
# @views predict!(values[:, :, i], submodel, states[i])
# end
# end


# using Base.Threads
# function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}})
# @threads for i=1:length(states)
# @views predict!(values[:, :, i], submodel, states[i])
# end
# end


using Distributed
function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}})
for i=1:length(states)
@views predict!(values[:, :, i], submodel, states[i])
@sync begin
for i=1:length(states)
@async begin
fetch(@spawn @views predict!(values[:, :, i], submodel, states[i]))
end
end
end
end


"""
"""
# function predict(submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}})
# # Construct and fill a matrix with the results from multiple states
# n, m, type = ACE.valtype(submodel.basis).parameters[3:5]
# values = Array{real(type), 3}(undef, n, m, length(states))
# predict!(values, submodel, states)
# return values
# end


using SharedArrays
function predict(submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}})
# Construct and fill a matrix with the results from multiple states
n, m, type = ACE.valtype(submodel.basis).parameters[3:5]
values = Array{real(type), 3}(undef, n, m, length(states))
# values = Array{real(type), 3}(undef, n, m, length(states))
A = SharedArray{real(type), 3}(n, m, length(states))
predict!(values, submodel, states)
return values
end
Expand All @@ -198,8 +229,12 @@ function predict!(values::Vector{<:Any}, submodel::AHSubModel, states::Vector{<:
end
end



# using Base.Threads
# function predict!(values::Vector{<:Any}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}})
# @threads for i=1:length(states)
# @views predict!(values[i], submodel, states[i])
# end
# end

"""
"""
Expand All @@ -214,6 +249,7 @@ function predict(model::Model, atoms::Atoms, cell_indices::Union{Nothing, Abstra
end
end


function _predict(model, atoms, cell_indices)

# Todo:-
Expand Down Expand Up @@ -308,6 +344,100 @@ function _predict(model, atoms, cell_indices)
end


# function _predict(model, atoms, cell_indices)

# # Todo:-
# # - use symmetry to prevent having to compute data for cells reflected
# # cell pairs; i.e. [ 0, 0, 1] & [ 0, 0, -1]
# # - Setting the on-sites to an identity should be determined by the model
# # rather than just assuming that the user always wants on-site overlap
# # blocks to be identity matrices.

# basis_def = model.basis_definition
# n_orbs = number_of_orbitals(atoms, basis_def)

# # Matrix into which the final results will be placed
# matrix = zeros(n_orbs, n_orbs, size(cell_indices, 2))

# # Mirror index map array required by `_reflect_block_idxs!`
# mirror_idxs = _mirror_idxs(cell_indices)

# # The on-site blocks of overlap matrices are approximated as identity matrix.
# if model.label ≡ "S"
# matrix[1:n_orbs+1:n_orbs^2] .= 1.0
# end

# for (species₁, species₂) in species_pairs(atoms::Atoms)

# # Matrix containing the block indices of all species₁-species₂ atom-blocks
# blockᵢ = repeat_atomic_block_idxs(
# atomic_block_idxs(species₁, species₂, atoms), size(cell_indices, 2))

# # Identify on-site sub-blocks now as they as static over the shell pair loop.
# # Note that when `species₁≠species₂` `length(on_blockᵢ)≡0`.
# on_blockᵢ = filter_on_site_idxs(blockᵢ)

# Threads.@threads for (shellᵢ, shellⱼ) in shell_pairs(species₁, species₂, basis_def)


# # Get the off-site basis associated with this interaction
# basis_off = model.off_site_submodels[(species₁, species₂, shellᵢ, shellⱼ)]

# # Identify off-site sub-blocks with bond-distances less than the specified cutoff
# off_blockᵢ = filter_idxs_by_bond_distance(
# filter_off_site_idxs(blockᵢ),
# envelope(basis_off).r0cut, atoms, cell_indices)

# # Blocks in the lower triangle are redundant in the homo-orbital interactions
# if species₁ ≡ species₂ && shellᵢ ≡ shellⱼ
# off_blockᵢ = filter_upper_idxs(off_blockᵢ)
# end

# off_site_states = _get_states( # Build states for the off-site atom-blocks
# off_blockᵢ, atoms, envelope(basis_off), cell_indices)

# # Don't try to compute off-site interactions if none exist
# if length(off_site_states) > 0
# let values = predict(basis_off, off_site_states) # Predict off-site sub-blocks
# set_sub_blocks!( # Assign off-site sub-blocks to the matrix
# matrix, values, off_blockᵢ, shellᵢ, shellⱼ, atoms, basis_def)


# _reflect_block_idxs!(off_blockᵢ, mirror_idxs)
# values = permutedims(values, (2, 1, 3))
# set_sub_blocks!( # Assign data to symmetrically equivalent blocks
# matrix, values, off_blockᵢ, shellⱼ, shellᵢ, atoms, basis_def)
# end
# end


# # Evaluate on-site terms for homo-atomic interactions; but only if not instructed
# # to approximate the on-site sub-blocks as identify matrices.
# if species₁ ≡ species₂ && model.label ≠ "S"
# # Get the on-site basis and construct the on-site states
# basis_on = model.on_site_submodels[(species₁, shellᵢ, shellⱼ)]
# on_site_states = _get_states(on_blockᵢ, atoms; r=radial(basis_on).R.ru)

# # Don't try to compute on-site interactions if none exist
# if length(on_site_states) > 0
# let values = predict(basis_on, on_site_states) # Predict on-site sub-blocks
# set_sub_blocks!( # Assign on-site sub-blocks to the matrix
# matrix, values, on_blockᵢ, shellᵢ, shellⱼ, atoms, basis_def)

# values = permutedims(values, (2, 1, 3))
# set_sub_blocks!( # Assign data to the symmetrically equivalent blocks
# matrix, values, on_blockᵢ, shellⱼ, shellᵢ, atoms, basis_def)
# end
# end
# end

# end
# end

# return matrix
# end


function _predict(model, atoms)
# Currently this method has the tendency to produce non-positive definite overlap
# matrices when working with the aluminum systems, however this is not observed in
Expand Down

0 comments on commit 69892d4

Please sign in to comment.