From 55d3486fffd7e22a6f9b8870fb1daa304b8e720e Mon Sep 17 00:00:00 2001 From: Siraaj Khandkar Date: Mon, 13 Jun 2022 15:38:55 -0700 Subject: [PATCH] Implement random elements selection from a stream via the optimal reservoir sampling algorithm. --- src/data/data_stream.erl | 129 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 128 insertions(+), 1 deletion(-) diff --git a/src/data/data_stream.erl b/src/data/data_stream.erl index 7bfdc8d24f..424febc9e5 100644 --- a/src/data/data_stream.erl +++ b/src/data/data_stream.erl @@ -16,11 +16,14 @@ lazy_map/2, lazy_filter/2, pmap_to_bag/2, - pmap_to_bag/3 + pmap_to_bag/3, + random_elements/2 ]). -define(T, ?MODULE). +-type reservoir(A) :: #{pos_integer() => A}. + -type filter(A, B) :: {map, fun((A) -> B)} | {test, fun((A) -> boolean())} @@ -190,8 +193,87 @@ pmap_to_bag(T, F, J) when is_function(F), is_integer(J), J > 0 -> error({data_stream_scheduler_crashed_before_sending_results, Reason}) end. +-spec random_elements(t(A), non_neg_integer()) -> [A]. +random_elements(_, 0) -> []; +random_elements(T, K) when K > 0 -> + {_N, Reservoir} = reservoir_sample(T, #{}, K), + [X || {_, X} <- maps:to_list(Reservoir)]. + %% Internal =================================================================== +%% @doc +%% The optimal reservoir sampling algorithm. Known as "Algorithm L" in: +%% https://dl.acm.org/doi/pdf/10.1145/198429.198435 +%% https://en.wikipedia.org/wiki/Reservoir_sampling#An_optimal_algorithm +%% @end +-spec reservoir_sample(t(A), reservoir(A), pos_integer()) -> + {pos_integer(), reservoir(A)}. +reservoir_sample(T0, R0, K) -> + case reservoir_sample_init(T0, R0, 1, K) of + {none, R1, I} -> + {I, R1}; + {{some, T1}, R1, I} -> + W = random_weight_init(K), + J = random_index_next(I, W), + reservoir_sample_update(T1, R1, W, I, J, K) + end. + +-spec reservoir_sample_init(t(A), reservoir(A), pos_integer(), pos_integer()) -> + {none | {some, A}, reservoir(A), pos_integer()}. +reservoir_sample_init(T0, R, I, K) -> + case I > K of + true -> + {{some, T0}, R, I - 1}; + false -> + case next(T0) of + {some, {X, T1}} -> + reservoir_sample_init(T1, R#{I => X}, I + 1, K); + none -> + {none, R, I - 1} + end + end. + +-spec random_weight_init(pos_integer()) -> float(). +random_weight_init(K) -> + math:exp(math:log(rand:uniform()) / K). + +-spec random_weight_next(float(), pos_integer()) -> float(). +random_weight_next(W, K) -> + W * random_weight_init(K). + +-spec random_index_next(pos_integer(), float()) -> pos_integer(). +random_index_next(I, W) -> + I + floor(math:log(rand:uniform()) / math:log(1 - W)) + 1. + +-spec reservoir_sample_update( + t(A), + reservoir(A), + float(), + pos_integer(), + pos_integer(), + pos_integer() +) -> + {pos_integer(), reservoir(A)}. +reservoir_sample_update(T0, R0, W0, I0, J0, K) -> + case next(T0) of + none -> + {I0, R0}; + {some, {X, T1}} -> + I1 = I0 + 1, + case I0 =:= J0 of + true -> + R1 = R0#{rand:uniform(K) => X}, + W1 = random_weight_next(W0, K), + J1 = random_index_next(J0, W0), + reservoir_sample_update(T1, R1, W1, I1, J1, K); + false -> + % Here is where the big win takes place over the simple + % Algorithm R. We skip computing random numbers for an + % element that will not be picked. + reservoir_sample_update(T1, R0, W0, I1, J0, K) + end + end. + -spec sched(#sched{}) -> [any()]. sched(#sched{id=_, producers=[], consumers=[], consumers_free=[], work=[], results=Ys}) -> Ys; @@ -396,4 +478,49 @@ fold_test_() -> ] ]. +random_elements_test_() -> + TestCases = + [ + ?_assertMatch([a], random_elements(from_list([a]), 1)), + ?_assertEqual(0, length(random_elements(from_list([]), 1))), + ?_assertEqual(0, length(random_elements(from_list([]), 10))), + ?_assertEqual(0, length(random_elements(from_list([]), 100))), + ?_assertEqual(1, length(random_elements(from_list(lists:seq(1, 100)), 1))), + ?_assertEqual(2, length(random_elements(from_list(lists:seq(1, 100)), 2))), + ?_assertEqual(3, length(random_elements(from_list(lists:seq(1, 100)), 3))), + ?_assertEqual(5, length(random_elements(from_list(lists:seq(1, 100)), 5))) + | + [ + (fun () -> + Trials = 10, + K = floor(N * KF), + L = lists:seq(1, N), + S = from_list(L), + Rands = + [ + random_elements(S, K) + || + _ <- lists:duplicate(Trials, {}) + ], + Head = lists:sublist(L, K), + Unique = lists:usort(Rands) -- [Head], + Name = + lists:flatten(io_lib:format( + "At least 1/~p of trials makes a new sequence. " + "N:~p K:~p KF:~p length(Unique):~p", + [Trials, N, K, KF, length(Unique)] + )), + {Name, ?_assertMatch([_|_], Unique)} + end)() + || + N <- lists:seq(10, 100), + KF <- [ + 0.25, + 0.50, + 0.75 + ] + ] + ], + {inparallel, TestCases}. + -endif.