-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ Add helpers for multi-thread functions
- Loading branch information
Showing
4 changed files
with
108 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
## Description ############################################################################# | ||
# | ||
# General helpers for the SatelliteToolbox.jl environment. | ||
# | ||
############################################################################################ | ||
|
||
############################################################################################ | ||
# Macros # | ||
############################################################################################ | ||
|
||
# == Threads =============================================================================== | ||
|
||
""" | ||
@maybe_threads(ntasks, expr) | ||
Run `expr` using `Threads.@threads` if `ntasks` is larger than 1. Otherwise, just run | ||
`expr`, avoiding the overhead. | ||
""" | ||
macro maybe_threads(ntasks, expr) | ||
expr = quote | ||
if $ntasks > 1 | ||
Threads.@threads $expr | ||
else | ||
$expr | ||
end | ||
end | ||
|
||
return esc(expr) | ||
end | ||
|
||
############################################################################################ | ||
# Public Functions # | ||
############################################################################################ | ||
|
||
# == Threads =============================================================================== | ||
|
||
""" | ||
get_partition(cp::Integer, v::AbstractVector, np::Integer) -> Int, Int | ||
Return the `cp`-th partition (start and end indices) of the vector `v` considering that we | ||
are partitioning it into `np` parts. | ||
This function is useful to splitting input information for spawning multiple tasks. | ||
!!! note | ||
- The function will clamp `np` if it is larger than the number of elements in `v`. | ||
- The function will clamp `cp` if it is larger than `np`. | ||
# Returns | ||
- `Int`: Current partition start index. | ||
- `Int`: Current partition last index. | ||
""" | ||
function get_partition(cp::Integer, v::AbstractVector, np::Integer) | ||
num_elements = length(v) | ||
|
||
# Check inputs. | ||
np = min(np, num_elements) | ||
cp = min(cp, np) | ||
|
||
len, rem = divrem(num_elements, np) | ||
|
||
i₀ = firstindex(v) + (cp - 1) * len | ||
i₁ = i₀ + len - 1 | ||
|
||
i₀ += cp <= rem ? cp - 1 : rem | ||
i₁ += cp <= rem ? cp : rem | ||
|
||
return i₀, i₁ | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
## Description ############################################################################# | ||
# | ||
# Tests related to helpers. | ||
# | ||
############################################################################################ | ||
|
||
@testset "Helpers" begin | ||
i₀, i₁ = SatelliteToolboxBase.get_partition(1, 1:1:10, 3) | ||
@test i₀ == 1 | ||
@test i₁ == 4 | ||
|
||
i₀, i₁ = SatelliteToolboxBase.get_partition(2, 1:1:10, 3) | ||
@test i₀ == 5 | ||
@test i₁ == 7 | ||
|
||
i₀, i₁ = SatelliteToolboxBase.get_partition(3, 1:1:10, 3) | ||
@test i₀ == 8 | ||
@test i₁ == 10 | ||
|
||
for i in 1:10 | ||
i₀, i₁ = SatelliteToolboxBase.get_partition(i, 1:1:10, 100) | ||
@test i₀ == i | ||
@test i₁ == i | ||
end | ||
|
||
for i in 11:100 | ||
i₀, i₁ = SatelliteToolboxBase.get_partition(i, 1:1:10, 100) | ||
@test i₀ == 10 | ||
@test i₁ == 10 | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters