Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add fn n_routes() to RouteProvider trait #584

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,20 @@ where
{
/// Returns the routing URL of the node handed out by the current routing snapshot.
///
/// # Errors
/// Returns `AgentError::RouteProviderError` when the snapshot does not yield a node.
fn route(&self) -> Result<Url, AgentError> {
// Load the currently published routing snapshot.
let snapshot = self.routing_snapshot.load();
// NOTE(review): the `snapshot.next()` line below looks like the removed side of the diff
// (superseded by the `next_node()` line that follows it) — confirm against the merged file.
let node = snapshot.next().ok_or_else(|| {
let node = snapshot.next_node().ok_or_else(|| {
AgentError::RouteProviderError("No healthy API nodes found.".to_string())
})?;
Ok(node.to_routing_url())
}

/// Returns the routing URLs of the nodes produced by `next_n_nodes(n)` on the current
/// routing snapshot, in the order the snapshot returned them.
///
/// # Errors
/// Returns `AgentError::RouteProviderError` when the snapshot yields `None`.
fn n_ordered_routes(&self, n: usize) -> Result<Vec<Url>, AgentError> {
    let snapshot = self.routing_snapshot.load();
    match snapshot.next_n_nodes(n) {
        Some(selected) => Ok(selected
            .iter()
            .map(|node| node.to_routing_url())
            .collect()),
        None => Err(AgentError::RouteProviderError(
            "No healthy API nodes found.".to_string(),
        )),
    }
}
}

impl<S> DynamicRouteProvider<S>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ where
// - failure should never happen, but we trace it if it does
loop {
let snapshot = self.routing_snapshot.load();
if let Some(node) = snapshot.next() {
if let Some(node) = snapshot.next_node() {
match self.fetcher.fetch((&node).into()).await {
Ok(nodes) => {
let msg = Some(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ impl LatencyRoutingSnapshot {
/// Helper function to sample nodes based on their weights.
/// Here weight index is selected based on the input number in range [0, 1]
#[inline(always)]
fn weighted_sample(weights: &[f64], number: f64) -> Option<usize> {
fn weighted_sample(weighted_nodes: &[(f64, &Node)], number: f64) -> Option<usize> {
if !(0.0..=1.0).contains(&number) {
return None;
}
let sum: f64 = weights.iter().sum();
let sum: f64 = weighted_nodes.iter().map(|n| n.0).sum();
let mut weighted_number = number * sum;
for (idx, weight) in weights.iter().enumerate() {
for (idx, &(weight, _)) in weighted_nodes.iter().enumerate() {
weighted_number -= weight;
if weighted_number <= 0.0 {
return Some(idx);
Expand All @@ -69,19 +69,40 @@ impl RoutingSnapshot for LatencyRoutingSnapshot {
!self.weighted_nodes.is_empty()
}

fn next(&self) -> Option<Node> {
// We select a node based on its weight, using a stochastic weighted random sampling approach.
let weights = self
/// Picks a single node by delegating to `next_n_nodes(1)`.
///
/// Returns `None` when that call yields `None` or an empty list.
fn next_node(&self) -> Option<Node> {
    self.next_n_nodes(1)
        .and_then(|selected| selected.into_iter().next())
}

/// Selects up to `n` distinct nodes by repeating weighted random sampling without
/// replacement: every drawn node is removed from the candidate list before the next draw.
///
/// The visible code always returns `Some`; the list is empty when `n == 0` or when the
/// snapshot holds no weighted nodes, and it can be shorter than requested if a draw
/// produces no index.
fn next_n_nodes(&self, n: usize) -> Option<Vec<Node>> {
// Nothing requested: answer with an empty selection right away.
if n == 0 {
return Some(Vec::new());
}

// Never attempt more draws than there are nodes available.
let n = std::cmp::min(n, self.weighted_nodes.len());

let mut nodes = Vec::with_capacity(n);

// Working list of (weight, node reference) pairs; entries are removed as they are drawn.
let mut weighted_nodes: Vec<_> = self
.weighted_nodes
.iter()
// NOTE(review): the next two lines and the comment after them look like leftovers of the
// removed `next()` implementation as rendered by the diff view — confirm against the merged file.
.map(|n| n.weight)
.collect::<Vec<_>>();
// Generate a random float in the range [0, 1)
.map(|n| (n.weight, &n.node))
.collect();

let mut rng = rand::thread_rng();
// NOTE(review): the following four lines reference `weights`, which is not defined in this
// function; presumably they are also removed lines from the diff — confirm.
let rand_num = rng.gen::<f64>();
// Using this random float and an array of weights we get an index of the node.
let idx = weighted_sample(weights.as_slice(), rand_num);
idx.map(|idx| self.weighted_nodes[idx].node.clone())

for _ in 0..n {
// Generate a random float in the range [0, 1)
let rand_num = rng.gen::<f64>();
// A `None` from `weighted_sample` skips this draw, leaving the result one node shorter.
if let Some(idx) = weighted_sample(weighted_nodes.as_slice(), rand_num) {
let node = weighted_nodes[idx].1;
nodes.push(node.clone());
// Remove the item, so that it can't be selected anymore.
// `swap_remove` moves the last entry into the freed slot, so candidate order is not kept.
weighted_nodes.swap_remove(idx);
}
}

Some(nodes)
}

fn sync_nodes(&mut self, nodes: &[Node]) -> bool {
Expand Down Expand Up @@ -143,7 +164,10 @@ impl RoutingSnapshot for LatencyRoutingSnapshot {

#[cfg(test)]
mod tests {
use std::{collections::HashSet, time::Duration};
use std::{
collections::{HashMap, HashSet},
time::Duration,
};

use simple_moving_average::SMA;

Expand All @@ -166,7 +190,7 @@ mod tests {
assert!(snapshot.weighted_nodes.is_empty());
assert!(snapshot.existing_nodes.is_empty());
assert!(!snapshot.has_nodes());
assert!(snapshot.next().is_none());
assert!(snapshot.next_node().is_none());
}

#[test]
Expand All @@ -181,7 +205,7 @@ mod tests {
assert!(!is_updated);
assert!(snapshot.weighted_nodes.is_empty());
assert!(!snapshot.has_nodes());
assert!(snapshot.next().is_none());
assert!(snapshot.next_node().is_none());
}

#[test]
Expand All @@ -201,7 +225,7 @@ mod tests {
Duration::from_secs(1)
);
assert_eq!(weighted_node.weight, 1.0);
assert_eq!(snapshot.next().unwrap(), node);
assert_eq!(snapshot.next_node().unwrap(), node);
// Check second update
let health = HealthCheckStatus::new(Some(Duration::from_secs(2)));
let is_updated = snapshot.update_node(&node, health);
Expand Down Expand Up @@ -232,7 +256,7 @@ mod tests {
assert_eq!(weighted_node.weight, 1.0 / avg_latency.as_secs_f64());
assert_eq!(snapshot.weighted_nodes.len(), 1);
assert_eq!(snapshot.existing_nodes.len(), 1);
assert_eq!(snapshot.next().unwrap(), node);
assert_eq!(snapshot.next_node().unwrap(), node);
}

#[test]
Expand Down Expand Up @@ -307,12 +331,13 @@ mod tests {

#[test]
fn test_weighted_sample() {
let node = &Node::new("api1.com").unwrap();
// Case 1: empty array
let arr: &[f64] = &[];
let arr = &[];
let idx = weighted_sample(arr, 0.5);
assert_eq!(idx, None);
// Case 2: single element in array
let arr: &[f64] = &[1.0];
let arr = &[(1.0, node)];
let idx = weighted_sample(arr, 0.0);
assert_eq!(idx, Some(0));
let idx = weighted_sample(arr, 1.0);
Expand All @@ -323,7 +348,7 @@ mod tests {
let idx = weighted_sample(arr, 1.1);
assert_eq!(idx, None);
// Case 3: two elements in array (second element has twice the weight of the first)
let arr: &[f64] = &[1.0, 2.0]; // prefixed_sum = [1.0, 3.0]
let arr = &[(1.0, node), (2.0, node)]; // // prefixed_sum = [1.0, 3.0]
let idx = weighted_sample(arr, 0.0); // 0.0 * 3.0 < 1.0
assert_eq!(idx, Some(0));
let idx = weighted_sample(arr, 0.33); // 0.33 * 3.0 < 1.0
Expand All @@ -338,7 +363,7 @@ mod tests {
let idx = weighted_sample(arr, 1.1);
assert_eq!(idx, None);
// Case 4: four elements in array
let arr: &[f64] = &[1.0, 2.0, 1.5, 2.5]; // prefixed_sum = [1.0, 3.0, 4.5, 7.0]
let arr = &[(1.0, node), (2.0, node), (1.5, node), (2.5, node)]; // prefixed_sum = [1.0, 3.0, 4.5, 7.0]
let idx = weighted_sample(arr, 0.14); // 0.14 * 7 < 1.0
assert_eq!(idx, Some(0)); // probability ~0.14
let idx = weighted_sample(arr, 0.15); // 0.15 * 7 > 1.0
Expand All @@ -359,4 +384,69 @@ mod tests {
let idx = weighted_sample(arr, 1.1);
assert_eq!(idx, None);
}

/// Prints how often each node is returned by `next_n_nodes` over a series of experiments.
///
/// The printed frequencies are meant for manual inspection (run the test with
/// `--nocapture`); add `#[ignore]` if it should only run on demand. The automated
/// checks are limited to properties the sampling loop guarantees: a selection never
/// holds more nodes than the snapshot and never contains the same node twice.
#[test]
fn test_stats_for_next_n_nodes() {
    // Arrange: six nodes with distinct domains and hand-picked weights.
    let mut snapshot = LatencyRoutingSnapshot::new();
    let latency_mov_avg = LatencyMovAvg::from_zero(Duration::ZERO);
    snapshot.weighted_nodes = [
        ("api2.com", 8.0),
        ("api3.com", 4.0),
        ("api1.com", 16.0),
        ("api6.com", 2.0),
        ("api5.com", 1.0),
        ("api4.com", 4.1),
    ]
    .into_iter()
    .map(|(domain, weight)| WeightedNode {
        node: Node::new(domain).unwrap(),
        latency_mov_avg: latency_mov_avg.clone(),
        weight,
    })
    .collect();

    // Act: repeat the selection and count how many times each node shows up.
    let mut stats: HashMap<Node, u32> = HashMap::new();
    let experiments: u32 = 30;
    let select_nodes_count = 10;
    for i in 0..experiments {
        let nodes = snapshot.next_n_nodes(select_nodes_count).unwrap();
        println!("Experiment {i}: selected nodes {nodes:?}");

        // Sampling is done without replacement, so one selection can't repeat a node
        // or return more nodes than the snapshot holds.
        assert!(nodes.len() <= snapshot.weighted_nodes.len());
        let unique: HashSet<_> = nodes.iter().collect();
        assert_eq!(unique.len(), nodes.len());

        for item in nodes {
            // Start counting from zero: seeding the entry with 1 and then adding 1
            // would report every node one selection too many.
            *stats.entry(item).or_insert(0) += 1;
        }
    }

    // Report: fraction of experiments in which each node was selected.
    for (node, count) in stats {
        assert!(count <= experiments);
        println!(
            "Node {:?} is selected with probability {}",
            node.domain(),
            f64::from(count) / f64::from(experiments)
        );
    }
}
}
Loading