From c92ed926aca5ae5b39f659289d7a383c7a75cd63 Mon Sep 17 00:00:00 2001 From: ChenQianA <11525069@zju.edu.cn> Date: Fri, 24 May 2024 20:43:14 +0100 Subject: [PATCH] improve the predict parallelization --- src/predicting.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/predicting.jl b/src/predicting.jl index cce846f..36d944f 100644 --- a/src/predicting.jl +++ b/src/predicting.jl @@ -191,7 +191,7 @@ function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states: @sync begin for i=1:length(states) @async begin - fetch(@spawn @views predict!(values[:, :, i], submodel, states[i])) + @spawn @views predict!(values[:, :, i], submodel, states[i]) end end end @@ -214,7 +214,7 @@ function predict(submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}} # Construct and fill a matrix with the results from multiple states n, m, type = ACE.valtype(submodel.basis).parameters[3:5] # values = Array{real(type), 3}(undef, n, m, length(states)) - A = SharedArray{real(type), 3}(n, m, length(states)) + values = SharedArray{real(type), 3}(n, m, length(states)) predict!(values, submodel, states) return values end