Skip to content

Commit

Permalink
improve the predict parallelization
Browse files Browse the repository at this point in the history
  • Loading branch information
ChenQianA committed May 24, 2024
1 parent 69892d4 commit c92ed92
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/predicting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states:
@sync begin
for i=1:length(states)
@async begin
fetch(@spawn @views predict!(values[:, :, i], submodel, states[i]))
@spawn @views predict!(values[:, :, i], submodel, states[i])
end
end
end
Expand All @@ -214,7 +214,7 @@ function predict(submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}
# Construct and fill a matrix with the results from multiple states
n, m, type = ACE.valtype(submodel.basis).parameters[3:5]
# values = Array{real(type), 3}(undef, n, m, length(states))
A = SharedArray{real(type), 3}(n, m, length(states))
values = SharedArray{real(type), 3}(n, m, length(states))
predict!(values, submodel, states)
return values
end
Expand Down

0 comments on commit c92ed92

Please sign in to comment.