Skip to content

Commit

Permalink
remove redundant calls to mtlconvert
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich committed Aug 23, 2024
1 parent 28576b3 commit 15785ce
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ macro metal(ex...)
$kernel_tt = Tuple{map(Core.Typeof, $kernel_args)...}
$kernel = $mtlfunction($kernel_f, $kernel_tt; $(compiler_kwargs...))
if $launch
$kernel($(var_exprs...); $(call_kwargs...))
$kernel($kernel_args...; $(call_kwargs...))
end
$kernel
end
Expand Down Expand Up @@ -203,7 +203,7 @@ const _kernel_instances = Dict{UInt, Any}()

## kernel launching and argument encoding

@inline @generated function encode_arguments!(cce, kernel, args...)
@inline @generated function encode_arguments!(cce, kernel::HostKernel{F,TT}, args...) where {F,TT}
ex = quote
bufs = MTLBuffer[]
end
Expand All @@ -216,6 +216,8 @@ const _kernel_instances = Dict{UInt, Any}()
idx = 1
for (argidx, argtyp) in enumerate(args)
argex = :(args[$argidx])
argTT = :(TT.parameters[$argidx])

if argtyp <: MTLBuffer
# top-level buffers are passed as a pointer-valued argument
push!(ex.args, :(set_buffer!(cce, $argex, 0, $idx)))
Expand All @@ -227,7 +229,11 @@ const _kernel_instances = Dict{UInt, Any}()
else
# everything else is passed by reference, in an argument buffer
append!(ex.args, (quote
buf = encode_argument!(kernel, mtlconvert($(argex), cce))
buf = if $argtyp != $argTT
encode_argument!(kernel, mtlconvert($(argex), cce))
else
encode_argument!(kernel, $argex)
end
set_buffer!(cce, buf, 0, $idx)
push!(bufs, buf)
end).args)
Expand Down Expand Up @@ -259,8 +265,7 @@ end
return argument_buffer
end

@autoreleasepool function (kernel::HostKernel)(args...; groups=1, threads=1,
queue=global_queue(device()))
@autoreleasepool function (kernel::HostKernel)(args...; groups=1, threads=1, queue=global_queue(device()))
groups = MTLSize(groups)
threads = MTLSize(threads)
(groups.width>0 && groups.height>0 && groups.depth>0) ||
Expand All @@ -276,7 +281,7 @@ end
cce = MTLComputeCommandEncoder(cmdbuf)
argument_buffers = try
MTL.set_function!(cce, kernel.pipeline)
bufs = encode_arguments!(cce, kernel, kernel.f, args...)
bufs = encode_arguments!(cce, kernel, args...)
MTL.append_current_function!(cce, groups, threads)
bufs
finally
Expand Down

0 comments on commit 15785ce

Please sign in to comment.