Skip to content

Commit

Permalink
Make better use of Metal APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd committed Sep 18, 2024
1 parent 6997e40 commit b6c6789
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions perf/byval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ const threads = 256

# simple kernel that adds two matrices element-wise
function kernel_add_mat(n, x1, x2, y)
i = (threadgroup_position_in_grid_2d().x-1) * threadgroups_per_grid_2d().x + thread_position_in_threadgroup_2d().x
i = thread_position_in_grid_1d()
if i <= n
@inbounds y[i] = x1[i] + x2[i]
end
Expand All @@ -20,7 +20,7 @@ end
# kernel that adds arrays of matrices
function kernel_add_mat_z_slices(n, vararg...)
x1, x2, y = get_inputs3(threadgroup_position_in_grid_2d().y, vararg...)
i = (threadgroup_position_in_grid_2d().x-1) * threadgroups_per_grid_2d().x + thread_position_in_threadgroup_2d().x
i = thread_position_in_grid_1d()
if i <= n
@inbounds y[i] = x1[i] + x2[i]
end
Expand Down
6 changes: 3 additions & 3 deletions perf/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,22 @@ group["launch"] = @benchmarkable @metal identity(nothing)
# Benchmark fixtures: a 512×1000 Float32 source array created with `Metal.rand`,
# and a destination obtained from `similar(src)`.
src = Metal.rand(Float32, 512, 1000)
dest = similar(src)
# Copy kernel without bounds checks: each thread copies a single element,
# `dest[i] = src[i]`, where `i` is the value returned by
# `thread_position_in_grid_1d()`.
#
# `@inbounds` elides the bounds checks, so the launch configuration must not
# produce an index outside either array.
function indexing_kernel(dest, src)
    # The hand-rolled threadgroup arithmetic that used to be assigned to `i`
    # first was overwritten before it was ever read; only the intrinsic remains.
    i = thread_position_in_grid_1d()
    @inbounds dest[i] = src[i]
    return
end
# Register the unchecked copy kernel under the "indexing" key. The launch uses
# `size(src,1)` threads per threadgroup and `size(src,2)` threadgroups.
group["indexing"] = @async_benchmarkable @metal threads=size(src,1) groups=size(src,2) $indexing_kernel($dest, $src)

# Bounds-checked counterpart of `indexing_kernel`: each thread copies a single
# element, `dest[i] = src[i]`, with no `@inbounds`, so the array accesses keep
# their bounds checks.
function checked_indexing_kernel(dest, src)
    # Use the intrinsic's value directly as the index, exactly as
    # `indexing_kernel` does. The previous `.x-1` suffix treated the 1D
    # position as a struct and shifted every index down by one, which made the
    # first thread access index 0.
    i = thread_position_in_grid_1d()
    dest[i] = src[i]
    return
end
# Register the bounds-checked copy kernel under the "indexing_checked" key,
# launched with the same thread/threadgroup counts as the "indexing" benchmark.
group["indexing_checked"] = @async_benchmarkable @metal threads=size(src,1) groups=size(src,2) $checked_indexing_kernel($dest, $src)

## DELETE
# function rand_kernel(dest::AbstractArray{T}) where {T}
# i = (threadgroup_position_in_grid_2d().x-1) * threadgroups_per_grid_2d().x + thread_position_in_threadgroup_2d().x
# i = thread_position_in_grid_1d()
# dest[i] = Metal.rand(T)
# return
# end
Expand Down

0 comments on commit b6c6789

Please sign in to comment.