From ecc39aadd3f8fd8578b97864f8c3d1b56c56cb73 Mon Sep 17 00:00:00 2001 From: chriselrod Date: Wed, 4 Aug 2021 10:03:30 -0400 Subject: [PATCH] In presence of mixed inner and outer reductions, place all inner reductions in innermost loop. --- src/modeling/determinestrategy.jl | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/src/modeling/determinestrategy.jl b/src/modeling/determinestrategy.jl index 219bec8d8..dc7129831 100644 --- a/src/modeling/determinestrategy.jl +++ b/src/modeling/determinestrategy.jl @@ -1182,7 +1182,28 @@ struct LoopOrders buff::Vector{Symbol} end -function LoopOrders(ls::LoopSet) +function outer_reduct_loopordersplit(ls::LoopSet) + ops = operations(ls) + nonouterreducts = Int[] + for i ∈ eachindex(ops) + i ∈ ls.outer_reductions || push!(nonouterreducts, i) + end + reductsyms = Symbol[] + nonreductsyms = Symbol[] + for l ∈ ls.loopsymbols + isreduct = false + for opid ∈ nonouterreducts + if l ∈ reduceddependencies(ops[opid]) + isreduct = true + push!(reductsyms, l) + break + end + end + isreduct || push!(nonreductsyms, l) + end + reductsyms, nonreductsyms +end +function loopordersplit(ls::LoopSet) reductsyms = Symbol[] nonreductsyms = Symbol[] for l ∈ ls.loopsymbols @@ -1196,6 +1217,14 @@ function LoopOrders(ls::LoopSet) end isreduct || push!(nonreductsyms, l) end + reductsyms, nonreductsyms +end +function LoopOrders(ls::LoopSet) + if length(ls.outer_reductions) == 0 + reductsyms, nonreductsyms = loopordersplit(ls) + else + reductsyms, nonreductsyms = outer_reduct_loopordersplit(ls) + end LoopOrders(nonreductsyms, reductsyms, Vector{Symbol}(undef, length(ls.loopsymbols))) end