Skip to content

Commit

Permalink
✨ Add helpers for multi-thread functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ronisbr committed Jun 14, 2024
1 parent 04e4309 commit 0412514
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/SatelliteToolboxBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ include("./constants.jl")

include("interfaces.jl")

include("helpers.jl")

include("./orbit/anomalies.jl")
include("./orbit/conversions.jl")
include("./orbit/kepler_to_rv.jl")
Expand Down
71 changes: 71 additions & 0 deletions src/helpers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
## Description #############################################################################
#
# General helpers for the SatelliteToolbox.jl environment.
#
############################################################################################

############################################################################################
# Macros #
############################################################################################

# == Threads ===============================================================================

"""
@maybe_threads(ntasks, expr)
Run `expr` using `Threads.@threads` if `ntasks` is larger than 1. Otherwise, just run
`expr`, avoiding the overhead.
"""
macro maybe_threads(ntasks, expr)
expr = quote
if $ntasks > 1
Threads.@threads $expr
else
$expr
end
end

return esc(expr)
end

############################################################################################
# Public Functions #
############################################################################################

# == Threads ===============================================================================

"""
get_partition(cp::Integer, v::AbstractVector, np::Integer) -> Int, Int
Return the `cp`-th partition (start and end indices) of the vector `v` considering that we
are partitioning it into `np` parts.
This function is useful to splitting input information for spawning multiple tasks.
!!! note
- The function will clamp `np` if it is larger than the number of elements in `v`.
- The function will clamp `cp` if it is larger than `np`.
# Returns
- `Int`: Current partition start index.
- `Int`: Current partition last index.
"""
function get_partition(cp::Integer, v::AbstractVector, np::Integer)
num_elements = length(v)

# Check inputs.
np = min(np, num_elements)
cp = min(cp, np)

len, rem = divrem(num_elements, np)

i₀ = firstindex(v) + (cp - 1) * len
i₁ = i₀ + len - 1

i₀ += cp <= rem ? cp - 1 : rem
i₁ += cp <= rem ? cp : rem

return i₀, i₁
end
31 changes: 31 additions & 0 deletions test/helpers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
## Description #############################################################################
#
# Tests related to helpers.
#
############################################################################################

@testset "Helpers" begin
i₀, i₁ = SatelliteToolboxBase.get_partition(1, 1:1:10, 3)
@test i₀ == 1
@test i₁ == 4

i₀, i₁ = SatelliteToolboxBase.get_partition(2, 1:1:10, 3)
@test i₀ == 5
@test i₁ == 7

i₀, i₁ = SatelliteToolboxBase.get_partition(3, 1:1:10, 3)
@test i₀ == 8
@test i₁ == 10

for i in 1:10
i₀, i₁ = SatelliteToolboxBase.get_partition(i, 1:1:10, 100)
@test i₀ == i
@test i₁ == i
end

for i in 11:100
i₀, i₁ = SatelliteToolboxBase.get_partition(i, 1:1:10, 100)
@test i₀ == 10
@test i₁ == 10
end
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ using StaticArrays
include("./ellipsoid.jl")
end

@testset "Helpers" verbose = true begin
include("./helpers.jl")
end

@testset "Interfaces" verbose = true begin
include("./interfaces.jl")
end
Expand Down

0 comments on commit 0412514

Please sign in to comment.