Skip to content

Commit

Permalink
handle many args better
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Nov 20, 2022
1 parent 0841acb commit e5189a8
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 45 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LoopVectorization"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
authors = ["Chris Elrod <[email protected]>"]
version = "0.12.140"
version = "0.12.141"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
19 changes: 7 additions & 12 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,9 @@ end
@inline ArrayInterface.device(::LowDimArray) = ArrayInterface.CPUPointer()
@generated function ArrayInterface.size(A::LowDimArray{D,T,N}) where {D,T,N}
t = Expr(:tuple)
gf = GlobalRef(Core, :getfield)
for n 1:N
if n > length(D) || D[n]
push!(t.args, Expr(:call, gf, :s, n, false))
push!(t.args, Expr(:call, getfield, :s, n))
else
push!(t.args, Expr(:call, Expr(:curly, lv(:StaticInt), 1)))
end
Expand All @@ -64,10 +63,9 @@ ArrayInterface.offsets(A::LowDimArray) = ArrayInterface.offsets(parent(A))

@generated function _lowdimfilter(::Val{D}, tup::Tuple{Vararg{Any,N}}) where {D,N}
t = Expr(:tuple)
gf = GlobalRef(Core, :getfield)
for n 1:N
if n > length(D) || D[n]
push!(t.args, Expr(:call, gf, :tup, n, false))
push!(t.args, Expr(:call, getfield, :tup, n))
end
end
Expr(:block, Expr(:meta, :inline), t)
Expand Down Expand Up @@ -178,7 +176,6 @@ function _strides_expr(@nospecialize(s), @nospecialize(x), R::Vector{Int}, D::Ve
N = length(R)
q = Expr(:block, Expr(:meta, :inline))
strd_tup = Expr(:tuple)
gf = GlobalRef(Core, :getfield)
ifel = GlobalRef(Core, :ifelse)
Nrange = 1:1:N # type stability w/ respect to reverse
use_stride_acc = true
Expand Down Expand Up @@ -207,7 +204,7 @@ function _strides_expr(@nospecialize(s), @nospecialize(x), R::Vector{Int}, D::Ve
elseif stride_acc 0
push!(strd_tup.args, staticexpr(stride_acc))
else
push!(strd_tup.args, :($gf(x, $n, false)))
push!(strd_tup.args, :($getfield(x, $n)))
end
else
if xₙ_static
Expand All @@ -217,7 +214,7 @@ function _strides_expr(@nospecialize(s), @nospecialize(x), R::Vector{Int}, D::Ve
else
push!(
strd_tup.args,
:($ifel(isone($gf(s, $n, false)), zero($xₙ_type), $gf(x, $n, false))),
:($ifel(isone($getfield(s, $n)), zero($xₙ_type), $getfield(x, $n))),
)
end
end
Expand Down Expand Up @@ -326,10 +323,9 @@ function add_broadcast!(
Klen = gensym!(ls, "K")
mA = gensym!(ls, "Aₘₖ")
mB = gensym!(ls, "Bₖₙ")
gf = GlobalRef(Core, :getfield)
pushprepreamble!(ls, Expr(:(=), mA, Expr(:(.), bcname, QuoteNode(:a))))
pushprepreamble!(ls, Expr(:(=), mB, Expr(:(.), bcname, QuoteNode(:b))))
pushprepreamble!(ls, Expr(:(=), Klen, Expr(:call, gf, Expr(:call, :size, mB), 1, false)))
pushprepreamble!(ls, Expr(:(=), Klen, Expr(:call, getfield, Expr(:call, :size, mB), 1)))
pushpreamble!(ls, Expr(:(=), Krange, Expr(:call, :(:), staticexpr(1), Klen)))
k = gensym!(ls, "k")
add_loop!(ls, Loop(k, 1, Klen, 1, Krange, Klen), k)
Expand Down Expand Up @@ -481,10 +477,9 @@ function add_broadcast!(
parents = Operation[]
deps = Symbol[]
# reduceddeps = Symbol[]
gf = GlobalRef(Core, :getfield)
for (i, arg) enumerate(args)
argname = gensym!(ls, "arg")
pushprepreamble!(ls, Expr(:(=), argname, Expr(:call, gf, bcargs, i, false)))
pushprepreamble!(ls, Expr(:(=), argname, Expr(:call, getfield, bcargs, i)))
# dynamic dispatch
parent = add_broadcast!(
ls,
Expand Down Expand Up @@ -539,7 +534,7 @@ end
::Val{UNROLL},
::Val{dontbc},
) where {T<:NativeTypes,N,BC<:Union{Broadcasted,Product},Mod,UNROLL,dontbc}
# 2 + 1
2 + 1
# we have an N dimensional loop.
# need to construct the LoopSet
ls = LoopSet(Mod)
Expand Down
68 changes: 40 additions & 28 deletions src/condense_loopset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ Base.:|(u::Unsigned, it::IndexType) = u | UInt8(it)
Base.:(==)(u::Unsigned, it::IndexType) = (u % UInt8) == UInt8(it)

function _append_fields!(t::Expr, body::Expr, sym::Symbol, ::Type{T}) where {T}
gf = GlobalRef(Core, :getfield)
for f 1:fieldcount(T)
TF = fieldtype(T, f)
Base.issingletontype(TF) && continue
gfcall = Expr(:call, gf, sym, f)
gfcall = Expr(:call, getfield, sym, f)
if fieldcount(TF) 0
push!(t.args, gfcall)
elseif TF <: DataType
Expand Down Expand Up @@ -37,16 +36,15 @@ end
body
end
function rebuild_fields(offset::Int, ::Type{T}) where {T}
gf = GlobalRef(Core, :getfield)
call = (T <: Tuple) ? Expr(:tuple) : Expr(:new, T)
for f 1:fieldcount(T)
TF = fieldtype(T, f)
if Base.issingletontype(TF)
push!(call.args, TF.instance)
elseif fieldcount(TF) 0
push!(call.args, Expr(:call, gf, :t, (offset += 1), false))
push!(call.args, Expr(:call, getfield, :t, (offset += 1)))
elseif TF <: DataType
push!(call.args, Expr(:call, lv(:gettype), Expr(:call, gf, :t, (offset += 1), false)))
push!(call.args, Expr(:call, lv(:gettype), Expr(:call, getfield, :t, (offset += 1))))
else
arg, offset = rebuild_fields(offset, TF)
push!(call.args, arg)
Expand All @@ -58,9 +56,9 @@ end
if Base.issingletontype(T)
return T.instance
elseif fieldcount(T) 0
call = Expr(:call, GlobalRef(Core, :getfield), :t, 1, false)
call = Expr(:call, getfield, :t, 1)
elseif T <: DataType
call = Expr(:call, lv(:gettype), Expr(:call, GlobalRef(Core, :getfield), :t, 1, false))
call = Expr(:call, lv(:gettype), Expr(:call, getfield, :t, 1))
else
call, _ = rebuild_fields(0, T)
end
Expand Down Expand Up @@ -377,10 +375,10 @@ val(x) = Expr(:call, Expr(:curly, :Val, x))
quote
$(Expr(:meta, :inline))
p, li =
VectorizationBase.tdot(x, (vsub_nsw(getfield(i, 1, false), one($I)),), strides(x))
VectorizationBase.tdot(x, (vsub_nsw(getfield(i, 1), one($I)),), strides(x))
ptr = gep(p, li)
si = ArrayInterface.StrideIndex{1,$(R[ri],),$(C === 1 ? 1 : 0)}(
(getfield(strides(x), $ri, false),),
(getfield(strides(x), $ri),),
(Zero(),),
)
stridedpointer(ptr, si, StaticInt{$(B === 1 ? 1 : 0)}())
Expand All @@ -394,8 +392,8 @@ end
quote
$(Expr(:meta, :inline))
si = ArrayInterface.StrideIndex{1,$(R[ri],),$(C === 1 ? 1 : 0)}(
(getfield(strides(x), $ri, false),),
(getfield(offsets(x), $ri, false),),
(getfield(strides(x), $ri),),
(getfield(offsets(x), $ri),),
)
stridedpointer(pointer(x), si, StaticInt{$(B == 1 ? 1 : 0)}())
end
Expand Down Expand Up @@ -550,7 +548,7 @@ function add_grouped_strided_pointer!(extra_args::Expr, ls::LoopSet)
push!(gsp.args, val(matcheddims))
gsps = gensym!(ls, "#grouped#strided#pointer#")
push!(extra_args.args, gsps)
pushpreamble!(ls, Expr(:(=), gsps, Expr(:call, GlobalRef(Core, :getfield), gsp, 1)))
pushpreamble!(ls, Expr(:(=), gsps, Expr(:call, getfield, gsp, 1)))
preserve, shouldindbyind, roots
end

Expand Down Expand Up @@ -802,21 +800,10 @@ function generate_call_types(
argmeta = argmeta_and_consts_description(ls, arraysymbolinds)
loop_bounds = loop_boundaries(ls, shouldindbyind)
loop_syms = tuple_expr(QuoteNode, ls.loopsymbols)
func = debug ? lv(:_turbo_loopset_debug) : lv(:_turbo_!)
lbarg = debug ? Expr(:call, :typeof, loop_bounds) : loop_bounds
configarg = (inline, u₁, u₂, v, ls.isbroadcast, thread, warncheckarg, safe)
unroll_param_tup =
Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), VECTORWIDTHSYMBOL)
q = Expr(
:call,
func,
unroll_param_tup,
val(operation_descriptions),
val(arrayref_descriptions),
val(argmeta),
val(loop_syms),
)

add_reassigned_syms!(extra_args, ls) # counterpart to `add_ops!` constants
for (opid, sym) ls.preamble_symsym # counterpart to process_metadata! symsym extraction
if instruction(ops[opid]) DROPPEDCONSTANT
Expand All @@ -826,17 +813,42 @@ function generate_call_types(
append!(extra_args.args, arraysymbolinds) # add_array_symbols!
add_external_functions!(extra_args, ls) # extract_external_functions!
add_outerreduct_types!(extra_args, ls) # extract_outerreduct_types!
if debug
vecwidthdefq = Expr(:block)
argcestimate = length(extra_args.args) - 1
for ref = ls.refs_aliasing_syms
argcestimate += length(ref.loopedindex)
end
manyarg = !debug && (argcestimate > 16)
func = debug ? lv(:_turbo_loopset_debug) : (manyarg ? lv(:_turbo_manyarg!) : lv(:_turbo_!))
q = Expr(
:call,
func,
unroll_param_tup,
val(operation_descriptions),
val(arrayref_descriptions),
val(argmeta),
val(loop_syms),
)
vecwidthdefq = if debug
push!(q.args, Expr(:tuple, lbarg, extra_args))
Expr(:block)
else
vargsym = gensym(:vargsym)
vecwidthdefq = Expr(:block, Expr(:(=), vargsym, Expr(:tuple, lbarg, extra_args)))
push!(
q.args,
Expr(:call, GlobalRef(Base, :Val), Expr(:call, GlobalRef(Base, :typeof), vargsym)),
Expr(:(...), Expr(:call, lv(:flatten_to_tuple), vargsym)),
Expr(:call, GlobalRef(Base, :Val), Expr(:call, GlobalRef(Base, :typeof), vargsym))
)
if manyarg
push!(
q.args,
Expr(:call, lv(:flatten_to_tuple), vargsym),
)
else
push!(
q.args,
Expr(:(...), Expr(:call, lv(:flatten_to_tuple), vargsym)),
)
end
Expr(:block, Expr(:(=), vargsym, Expr(:tuple, lbarg, extra_args)))
end
define_eltype_vec_width!(vecwidthdefq, ls, nothing, true)
push!(vecwidthdefq.args, q)
Expand Down
65 changes: 61 additions & 4 deletions src/reconstruct_loopset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ function Loop(
end


extract_loop(l) = Expr(:call, GlobalRef(Core, :getfield), Symbol("#loop#bounds#"), l, false)
extract_loop(l) = Expr(:call, getfield, Symbol("#loop#bounds#"), l)

function add_loops!(ls::LoopSet, LPSYM, LB)
n = max(length(LPSYM), length(LB))
Expand All @@ -145,7 +145,7 @@ function add_loops!(
ssym = String(sym)
for k = N:-1:1
axisexpr =
:(getfield(getfield(getfield(var"#loop#bounds#", $i, false), :indices), $k, false))
:($getfield($getfield($getfield(var"#loop#bounds#", $i), :indices), $k))
add_loop!(
ls,
Loop(ls, axisexpr, Symbol(ssym * '#' * string(k) * '#'), T.parameters[k])::Loop,
Expand Down Expand Up @@ -258,7 +258,7 @@ function ArrayReferenceMeta(
end


extract_varg(i) = :(getfield(var"#vargs#", $i, false))
extract_varg(i) = :($getfield(var"#vargs#", $i))
# _extract(::Type{StaticInt{N}}) where {N} = N
extract_gsp!(sptrs::Expr, name::Symbol) = (push!(sptrs.args, name); nothing)
tupleranks(R::NTuple{8,Int}) = ntuple(n -> sum(R[n] .≥ R), Val{8}())
Expand Down Expand Up @@ -319,7 +319,7 @@ function _add_mref!(
extract_gsp!(sptrs, tmpsp)
strd_tup = Expr(:tuple)
offsets_tup = Expr(:tuple)
gf = GlobalRef(Core, :getfield)
gf = getfield
offsets = gensym(:offsets)
strides = gensym(:strides)
pushpreamble!(ls, Expr(:(=), offsets, Expr(:call, lv(:offsets), tmpsp)))
Expand Down Expand Up @@ -1019,3 +1019,60 @@ Execute an `@turbo` block. The block's code is represented via the arguments:
post === ls.preamble ? q : Expr(:block, q, post)
# @show var"#UNROLL#", var"#OPS#", var"#ARF#", var"#AM#", var"#LPSYM#", var"#LB#"
end
@generated function _turbo_manyarg!(
::Val{var"#UNROLL#"},
::Val{var"#OPS#"},
::Val{var"#ARF#"},
::Val{var"#AM#"},
::Val{var"#LPSYM#"},
::Val{Tuple{var"#LB#",var"#V#"}},
var"#flattened#var#arguments#"::Tuple{Vararg{Any,var"#num#vargs#"}},
) where {
var"#UNROLL#",
var"#OPS#",
var"#ARF#",
var"#AM#",
var"#LPSYM#",
var"#LB#",
var"#V#",
var"#num#vargs#",
}
1 + 1 # Irrelevant line you can comment out/in to force recompilation...
ls = _turbo_loopset(
var"#OPS#",
var"#ARF#",
var"#AM#",
var"#LPSYM#",
var"#LB#".parameters,
var"#V#".parameters,
var"#UNROLL#",
)
pushfirst!(
ls.preamble.args,
:(
var"#lv#tuple#args#" =
reassemble_tuple(Tuple{var"#LB#",var"#V#"}, var"#flattened#var#arguments#")
),
)
post = hoist_constant_memory_accesses!(ls)
# q = @show(avx_body(ls, var"#UNROLL#")); post === ls.preamble ? q : Expr(:block, q, post)
q = if (var"#UNROLL#"[10] > 1) && length(var"#LPSYM#") == length(ls.loops)
inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, nt, wca, safe = var"#UNROLL#"
# wrap in `var"#OPS#", var"#ARF#", var"#AM#", var"#LPSYM#"` in `Expr` to homogenize types
avx_threads_expr(
ls,
(inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, one(UInt), wca, safe),
nt,
:(Val{$(var"#OPS#")}()),
:(Val{$(var"#ARF#")}()),
:(Val{$(var"#AM#")}()),
:(Val{$(var"#LPSYM#")}()),
)
else
# Main.BODY[] = avx_body(ls, var"#UNROLL#")
# return @show avx_body(ls, var"#UNROLL#")
avx_body(ls, var"#UNROLL#")
end
post === ls.preamble ? q : Expr(:block, q, post)
# @show var"#UNROLL#", var"#OPS#", var"#ARF#", var"#AM#", var"#LPSYM#", var"#LB#"
end

2 comments on commit e5189a8

@chriselrod
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/72541

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.12.141 -m "<description of version>" e5189a8ca47ef9c410a76934e24a364d2ba44adc
git push origin v0.12.141

Please sign in to comment.