diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl index 1fd4f3b9..75339ceb 100644 --- a/src/compiler/execution.jl +++ b/src/compiler/execution.jl @@ -84,7 +84,7 @@ macro metal(ex...) $kernel_tt = Tuple{map(Core.Typeof, $kernel_args)...} $kernel = $mtlfunction($kernel_f, $kernel_tt; $(compiler_kwargs...)) if $launch - $kernel($(var_exprs...); $(call_kwargs...)) + $kernel($kernel_args...; $(call_kwargs...)) end $kernel end @@ -203,7 +203,7 @@ const _kernel_instances = Dict{UInt, Any}() ## kernel launching and argument encoding -@inline @generated function encode_arguments!(cce, kernel, args...) +@inline @generated function encode_arguments!(cce, kernel::HostKernel{F,TT}, args...) where {F,TT} ex = quote bufs = MTLBuffer[] end @@ -216,6 +216,8 @@ const _kernel_instances = Dict{UInt, Any}() idx = 1 for (argidx, argtyp) in enumerate(args) argex = :(args[$argidx]) + argTT = :(TT.parameters[$argidx]) + if argtyp <: MTLBuffer # top-level buffers are passed as a pointer-valued argument push!(ex.args, :(set_buffer!(cce, $argex, 0, $idx))) @@ -227,7 +229,11 @@ const _kernel_instances = Dict{UInt, Any}() else # everything else is passed by reference, in an argument buffer append!(ex.args, (quote - buf = encode_argument!(kernel, mtlconvert($(argex), cce)) + buf = if $argtyp != $argTT + encode_argument!(kernel, mtlconvert($(argex), cce)) + else + encode_argument!(kernel, $argex) + end set_buffer!(cce, buf, 0, $idx) push!(bufs, buf) end).args) @@ -259,8 +265,7 @@ end return argument_buffer end -@autoreleasepool function (kernel::HostKernel)(args...; groups=1, threads=1, - queue=global_queue(device())) +@autoreleasepool function (kernel::HostKernel)(args...; groups=1, threads=1, queue=global_queue(device())) groups = MTLSize(groups) threads = MTLSize(threads) (groups.width>0 && groups.height>0 && groups.depth>0) || @@ -276,7 +281,7 @@ end cce = MTLComputeCommandEncoder(cmdbuf) argument_buffers = try MTL.set_function!(cce, kernel.pipeline) - bufs = encode_arguments!(cce, kernel, kernel.f, args...) + bufs = encode_arguments!(cce, kernel, args...) MTL.append_current_function!(cce, groups, threads) bufs finally