From 72bfe9371ce5b59e45a1441be760d2407b00b969 Mon Sep 17 00:00:00 2001 From: Dennis Yatunin Date: Thu, 4 Apr 2024 15:27:23 -0700 Subject: [PATCH] Add unrolled_enumerate and unrolled_applyat --- Project.toml | 4 +--- docs/src/index.md | 11 ++++++---- src/UnrolledUtilities.jl | 13 +++++++++++- test/test_and_analyze.jl | 44 ++++++++++++++++++++++++++++++++++++++-- 4 files changed, 62 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 81c32b4..62f5f0a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,7 @@ name = "UnrolledUtilities" uuid = "0fe1646c-419e-43be-ac14-22321958931b" authors = ["CliMA Contributors "] -version = "0.1.1" - -[deps] +version = "0.1.2" [compat] julia = "1.10" diff --git a/docs/src/index.md b/docs/src/index.md index e3d8b6b..8aaec38 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -8,6 +8,8 @@ A collection of generated functions in which all loops are unrolled and inlined: - `unrolled_reduce(op, itr; [init])`: similar to `reduce` - `unrolled_mapreduce(f, op, itrs...; [init])`: similar to `mapreduce` - `unrolled_zip(itrs...)`: similar to `zip` +- `unrolled_enumerate(itrs...)`: similar to `enumerate`, but with the ability to + handle multiple iterators - `unrolled_in(item, itr)`: similar to `in` - `unrolled_unique(itr)`: similar to `unique` - `unrolled_filter(f, itr)`: similar to `filter` @@ -16,10 +18,11 @@ A collection of generated functions in which all loops are unrolled and inlined: - `unrolled_flatten(itr)`: similar to `Iterators.flatten` - `unrolled_flatmap(f, itrs...)`: similar to `Iterators.flatmap` - `unrolled_product(itrs...)`: similar to `Iterators.product` -- `unrolled_take(itr, ::Val{N})`: similar to `Iterators.take` or `itr[1:N]`, but - with `N` wrapped in a `Val` -- `unrolled_drop(itr, ::Val{N})`: similar to `Iterators.drop` or - `itr[(N + 1):end]`, but with `N` wrapped in a `Val` +- `unrolled_applyat(f, n, itrs...)`: similar to `f(map(itr -> itr[n], itrs)...)` +- `unrolled_take(itr, ::Val{N})`: similar to `itr[1:N]` (and to + `Iterators.take`), but with `N` wrapped in a `Val` +- `unrolled_drop(itr, ::Val{N})`: similar to `itr[(N + 1):end]` (and to + `Iterators.drop`), but with `N` wrapped in a `Val` These functions are guaranteed to be type-stable whenever they are given iterators with inferrable lengths and element types, including when diff --git a/src/UnrolledUtilities.jl b/src/UnrolledUtilities.jl index d0931b9..dc69559 100644 --- a/src/UnrolledUtilities.jl +++ b/src/UnrolledUtilities.jl @@ -7,6 +7,7 @@ export unrolled_any, unrolled_reduce, unrolled_mapreduce, unrolled_zip, + unrolled_enumerate, unrolled_in, unrolled_unique, unrolled_filter, @@ -14,10 +15,11 @@ export unrolled_any, unrolled_flatten, unrolled_flatmap, unrolled_product, + unrolled_applyat, unrolled_take, unrolled_drop -inferred_length(itr_type::Type{<:Tuple}) = length(itr_type.types) +inferred_length(::Type{<:NTuple{N, Any}}) where {N} = N # We could also add support for statically-sized iterators that are not Tuples. f_exprs(itr_type) = (:(f(itr[$n])) for n in 1:inferred_length(itr_type)) @@ -52,6 +54,9 @@ struct NoInit end @inline unrolled_zip(itrs...) = unrolled_map(tuple, itrs...) +@inline unrolled_enumerate(itrs...) = + unrolled_zip(ntuple(identity, Val(length(itrs[1]))), itrs...) + @inline unrolled_in(item, itr) = unrolled_any(Base.Fix1(===, item), itr) # Using === instead of == or isequal improves type stability for singletons. @@ -89,6 +94,11 @@ struct NoInit end end end +@inline unrolled_applyat(f, n, itrs...) = unrolled_foreach( + (i, items...) -> i == n && f(items...), + unrolled_enumerate(itrs...), +) + @inline unrolled_take(itr, ::Val{N}) where {N} = ntuple(i -> itr[i], Val(N)) @inline unrolled_drop(itr, ::Val{N}) where {N} = ntuple(i -> itr[N + i], Val(length(itr) - N)) @@ -107,6 +117,7 @@ struct NoInit end unrolled_filter, unrolled_split, unrolled_flatmap, + unrolled_applyat, ) for method in methods(func) method.recursion_relation = (_...) -> true diff --git a/test/test_and_analyze.jl b/test/test_and_analyze.jl index 07de628..70415e3 100644 --- a/test/test_and_analyze.jl +++ b/test/test_and_analyze.jl @@ -181,14 +181,14 @@ macro test_unrolled(args_expr, unrolled_expr, reference_expr, contents_info_str) arg_definitions_str = join(arg_definition_strs, '\n') unrolled_command_str = """ using UnrolledUtilities - unrolled_func($arg_names_str) = $($unrolled_expr_str) + unrolled_func($arg_names_str) = $($(string(unrolled_expr))) $arg_definitions_str stats1 = @timed unrolled_func($arg_names_str) stats2 = @timed unrolled_func($arg_names_str) print(stats1.time - stats2.time, ',', stats1.bytes - stats2.bytes) """ reference_command_str = """ - reference_func($arg_names_str) = $($reference_expr_str) + reference_func($arg_names_str) = $($(string(reference_expr))) $arg_definitions_str stats1 = @timed reference_func($arg_names_str) stats2 = @timed reference_func($arg_names_str) @@ -344,6 +344,8 @@ for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false)) @test_unrolled (itr,) unrolled_zip(itr) Tuple(zip(itr)) str + @test_unrolled (itr,) unrolled_enumerate(itr) Tuple(enumerate(itr)) str + @test_unrolled (itr,) unrolled_in(nothing, itr) (nothing in itr) str @test_unrolled (itr,) unrolled_in(itr[1], itr) (itr[1] in itr) str @test_unrolled (itr,) unrolled_in(itr[end], itr) (itr[end] in itr) str @@ -388,6 +390,17 @@ for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false)) str, ) + @test_unrolled( + (itr,), + unrolled_applyat( + x -> @assert(length(x) <= 7), + rand(1:length(itr)), + itr, + ), + @assert(length(itr[rand(1:length(itr))]) <= 7), + str, + ) + if n > 1 @test_unrolled( (itr,), @@ -439,6 +452,33 @@ for n in (1, 8, 32, 33, 128), identical in (n == 1 ? (true,) : (true, false)) str23, ) + @test_unrolled( + (itr1, itr2), + unrolled_applyat( + (x1, x2) -> @assert(length(x1) < length(x2)), + rand(1:length(itr1)), + itr1, + itr2, + ), + let n = rand(1:length(itr1)) + @assert(length(itr1[n]) < length(itr2[n])) + end, + str12, + ) + @test_unrolled( + (itr2, itr3), + unrolled_applyat( + (x2, x3) -> @assert(x2 == unrolled_map(Val, x3)), + rand(1:length(itr2)), + itr2, + itr3, + ), + let n = rand(1:length(itr2)) + @assert(itr2[n] == map(Val, itr3[n])) + end, + str23, + ) + @test_unrolled( (itr1, itr2), unrolled_zip(itr1, itr2),