diff --git a/ic-agent/src/agent/http_transport/dynamic_routing/dynamic_route_provider.rs b/ic-agent/src/agent/http_transport/dynamic_routing/dynamic_route_provider.rs index cb657ae2..855fe566 100644 --- a/ic-agent/src/agent/http_transport/dynamic_routing/dynamic_route_provider.rs +++ b/ic-agent/src/agent/http_transport/dynamic_routing/dynamic_route_provider.rs @@ -170,11 +170,20 @@ where { fn route(&self) -> Result { let snapshot = self.routing_snapshot.load(); - let node = snapshot.next().ok_or_else(|| { + let node = snapshot.next_node().ok_or_else(|| { AgentError::RouteProviderError("No healthy API nodes found.".to_string()) })?; Ok(node.to_routing_url()) } + + fn n_ordered_routes(&self, n: usize) -> Result, AgentError> { + let snapshot = self.routing_snapshot.load(); + let nodes = snapshot.next_n_nodes(n).ok_or_else(|| { + AgentError::RouteProviderError("No healthy API nodes found.".to_string()) + })?; + let urls = nodes.iter().map(|n| n.to_routing_url()).collect(); + Ok(urls) + } } impl DynamicRouteProvider diff --git a/ic-agent/src/agent/http_transport/dynamic_routing/nodes_fetch.rs b/ic-agent/src/agent/http_transport/dynamic_routing/nodes_fetch.rs index 7e01d145..e887e668 100644 --- a/ic-agent/src/agent/http_transport/dynamic_routing/nodes_fetch.rs +++ b/ic-agent/src/agent/http_transport/dynamic_routing/nodes_fetch.rs @@ -143,7 +143,7 @@ where // - failure should never happen, but we trace it if it does loop { let snapshot = self.routing_snapshot.load(); - if let Some(node) = snapshot.next() { + if let Some(node) = snapshot.next_node() { match self.fetcher.fetch((&node).into()).await { Ok(nodes) => { let msg = Some( diff --git a/ic-agent/src/agent/http_transport/dynamic_routing/snapshot/latency_based_routing.rs b/ic-agent/src/agent/http_transport/dynamic_routing/snapshot/latency_based_routing.rs index 1ae10136..7de1bbc6 100644 --- a/ic-agent/src/agent/http_transport/dynamic_routing/snapshot/latency_based_routing.rs +++ b/ic-agent/src/agent/http_transport/dynamic_routing/snapshot/latency_based_routing.rs @@ -49,13 +49,13 @@ impl LatencyRoutingSnapshot { /// Helper function to sample nodes based on their weights. /// Here weight index is selected based on the input number in range [0, 1] #[inline(always)] -fn weighted_sample(weights: &[f64], number: f64) -> Option { +fn weighted_sample(weighted_nodes: &[(f64, &Node)], number: f64) -> Option { if !(0.0..=1.0).contains(&number) { return None; } - let sum: f64 = weights.iter().sum(); + let sum: f64 = weighted_nodes.iter().map(|n| n.0).sum(); let mut weighted_number = number * sum; - for (idx, weight) in weights.iter().enumerate() { + for (idx, &(weight, _)) in weighted_nodes.iter().enumerate() { weighted_number -= weight; if weighted_number <= 0.0 { return Some(idx); @@ -69,19 +69,40 @@ impl RoutingSnapshot for LatencyRoutingSnapshot { !self.weighted_nodes.is_empty() } - fn next(&self) -> Option { - // We select a node based on it's weight, using a stochastic weighted random sampling approach. - let weights = self + fn next_node(&self) -> Option { + self.next_n_nodes(1).unwrap_or_default().into_iter().next() + } + + // Uses weighted random sampling algorithm without item replacement n times. + fn next_n_nodes(&self, n: usize) -> Option> { + if n == 0 { + return Some(Vec::new()); + } + + let n = std::cmp::min(n, self.weighted_nodes.len()); + + let mut nodes = Vec::with_capacity(n); + + let mut weighted_nodes: Vec<_> = self .weighted_nodes .iter() - .map(|n| n.weight) - .collect::>(); - // Generate a random float in the range [0, 1) + .map(|n| (n.weight, &n.node)) + .collect(); + let mut rng = rand::thread_rng(); - let rand_num = rng.gen::(); - // Using this random float and an array of weights we get an index of the node. - let idx = weighted_sample(weights.as_slice(), rand_num); - idx.map(|idx| self.weighted_nodes[idx].node.clone()) + + for _ in 0..n { + // Generate a random float in the range [0, 1) + let rand_num = rng.gen::(); + if let Some(idx) = weighted_sample(weighted_nodes.as_slice(), rand_num) { + let node = weighted_nodes[idx].1; + nodes.push(node.clone()); + // Remove the item, so that it can't be selected anymore. + weighted_nodes.swap_remove(idx); + } + } + + Some(nodes) } fn sync_nodes(&mut self, nodes: &[Node]) -> bool { @@ -143,7 +164,10 @@ impl RoutingSnapshot for LatencyRoutingSnapshot { #[cfg(test)] mod tests { - use std::{collections::HashSet, time::Duration}; + use std::{ + collections::{HashMap, HashSet}, + time::Duration, + }; use simple_moving_average::SMA; @@ -166,7 +190,7 @@ mod tests { assert!(snapshot.weighted_nodes.is_empty()); assert!(snapshot.existing_nodes.is_empty()); assert!(!snapshot.has_nodes()); - assert!(snapshot.next().is_none()); + assert!(snapshot.next_node().is_none()); } #[test] @@ -181,7 +205,7 @@ mod tests { assert!(!is_updated); assert!(snapshot.weighted_nodes.is_empty()); assert!(!snapshot.has_nodes()); - assert!(snapshot.next().is_none()); + assert!(snapshot.next_node().is_none()); } #[test] @@ -201,7 +225,7 @@ mod tests { Duration::from_secs(1) ); assert_eq!(weighted_node.weight, 1.0); - assert_eq!(snapshot.next().unwrap(), node); + assert_eq!(snapshot.next_node().unwrap(), node); // Check second update let health = HealthCheckStatus::new(Some(Duration::from_secs(2))); let is_updated = snapshot.update_node(&node, health); @@ -232,7 +256,7 @@ mod tests { assert_eq!(weighted_node.weight, 1.0 / avg_latency.as_secs_f64()); assert_eq!(snapshot.weighted_nodes.len(), 1); assert_eq!(snapshot.existing_nodes.len(), 1); - assert_eq!(snapshot.next().unwrap(), node); + assert_eq!(snapshot.next_node().unwrap(), node); } #[test] @@ -307,12 +331,13 @@ mod tests { #[test] fn test_weighted_sample() { + let node = &Node::new("api1.com").unwrap(); // Case 1: empty array - let arr: &[f64] = &[]; + let arr = &[]; let idx = weighted_sample(arr, 0.5); assert_eq!(idx, None); // Case 2: single element in array - let arr: &[f64] = &[1.0]; + let arr = &[(1.0, node)]; let idx = weighted_sample(arr, 0.0); assert_eq!(idx, Some(0)); let idx = weighted_sample(arr, 1.0); @@ -323,7 +348,7 @@ mod tests { let idx = weighted_sample(arr, 1.1); assert_eq!(idx, None); // Case 3: two elements in array (second element has twice the weight of the first) - let arr: &[f64] = &[1.0, 2.0]; // prefixed_sum = [1.0, 3.0] + let arr = &[(1.0, node), (2.0, node)]; // // prefixed_sum = [1.0, 3.0] let idx = weighted_sample(arr, 0.0); // 0.0 * 3.0 < 1.0 assert_eq!(idx, Some(0)); let idx = weighted_sample(arr, 0.33); // 0.33 * 3.0 < 1.0 @@ -338,7 +363,7 @@ mod tests { let idx = weighted_sample(arr, 1.1); assert_eq!(idx, None); // Case 4: four elements in array - let arr: &[f64] = &[1.0, 2.0, 1.5, 2.5]; // prefixed_sum = [1.0, 3.0, 4.5, 7.0] + let arr = &[(1.0, node), (2.0, node), (1.5, node), (2.5, node)]; // prefixed_sum = [1.0, 3.0, 4.5, 7.0] let idx = weighted_sample(arr, 0.14); // 0.14 * 7 < 1.0 assert_eq!(idx, Some(0)); // probability ~0.14 let idx = weighted_sample(arr, 0.15); // 0.15 * 7 > 1.0 @@ -359,4 +384,69 @@ mod tests { let idx = weighted_sample(arr, 1.1); assert_eq!(idx, None); } + + #[test] + // #[ignore] + // This test is for manual runs to see the statistics for nodes selection probability. + fn test_stats_for_next_n_nodes() { + // Arrange + let mut snapshot = LatencyRoutingSnapshot::new(); + let node_1 = Node::new("api1.com").unwrap(); + let node_2 = Node::new("api2.com").unwrap(); + let node_3 = Node::new("api3.com").unwrap(); + let node_4 = Node::new("api4.com").unwrap(); + let node_5 = Node::new("api5.com").unwrap(); + let node_6 = Node::new("api6.com").unwrap(); + let latency_mov_avg = LatencyMovAvg::from_zero(Duration::ZERO); + snapshot.weighted_nodes = vec![ + WeightedNode { + node: node_2.clone(), + latency_mov_avg: latency_mov_avg.clone(), + weight: 8.0, + }, + WeightedNode { + node: node_3.clone(), + latency_mov_avg: latency_mov_avg.clone(), + weight: 4.0, + }, + WeightedNode { + node: node_1.clone(), + latency_mov_avg: latency_mov_avg.clone(), + weight: 16.0, + }, + WeightedNode { + node: node_6.clone(), + latency_mov_avg: latency_mov_avg.clone(), + weight: 2.0, + }, + WeightedNode { + node: node_5.clone(), + latency_mov_avg: latency_mov_avg.clone(), + weight: 1.0, + }, + WeightedNode { + node: node_4.clone(), + latency_mov_avg: latency_mov_avg.clone(), + weight: 4.1, + }, + ]; + + let mut stats = HashMap::new(); + let experiments = 30; + let select_nodes_count = 10; + for i in 0..experiments { + let nodes = snapshot.next_n_nodes(select_nodes_count).unwrap(); + println!("Experiment {i}: selected nodes {nodes:?}"); + for item in nodes.into_iter() { + *stats.entry(item).or_insert(1) += 1; + } + } + for (node, count) in stats { + println!( + "Node {:?} is selected with probability {}", + node.domain(), + count as f64 / experiments as f64 + ); + } + } } diff --git a/ic-agent/src/agent/http_transport/dynamic_routing/snapshot/round_robin_routing.rs b/ic-agent/src/agent/http_transport/dynamic_routing/snapshot/round_robin_routing.rs index 149e49d2..318eb379 100644 --- a/ic-agent/src/agent/http_transport/dynamic_routing/snapshot/round_robin_routing.rs +++ b/ic-agent/src/agent/http_transport/dynamic_routing/snapshot/round_robin_routing.rs @@ -34,7 +34,7 @@ impl RoutingSnapshot for RoundRobinRoutingSnapshot { !self.healthy_nodes.is_empty() } - fn next(&self) -> Option { + fn next_node(&self) -> Option { if self.healthy_nodes.is_empty() { return None; } @@ -45,6 +45,31 @@ impl RoutingSnapshot for RoundRobinRoutingSnapshot { .cloned() } + fn next_n_nodes(&self, n: usize) -> Option> { + if n == 0 { + return Some(Vec::new()); + } + + let healthy_nodes = Vec::from_iter(self.healthy_nodes.clone()); + let healthy_count = healthy_nodes.len(); + + if n >= healthy_count { + return Some(healthy_nodes.clone()); + } + + let idx = self.current_idx.fetch_add(n, Ordering::Relaxed) % healthy_count; + let mut nodes = Vec::with_capacity(n); + + if healthy_count - idx >= n { + nodes.extend_from_slice(&healthy_nodes[idx..idx + n]); + } else { + nodes.extend_from_slice(&healthy_nodes[idx..]); + nodes.extend_from_slice(&healthy_nodes[..n - nodes.len()]); + } + + Some(nodes) + } + fn sync_nodes(&mut self, nodes: &[Node]) -> bool { let new_nodes = HashSet::from_iter(nodes.iter().cloned()); // Find nodes removed from topology. @@ -85,6 +110,7 @@ impl RoutingSnapshot for RoundRobinRoutingSnapshot { #[cfg(test)] mod tests { + use std::collections::HashMap; use std::time::Duration; use std::{collections::HashSet, sync::atomic::Ordering}; @@ -105,7 +131,7 @@ mod tests { assert!(snapshot.existing_nodes.is_empty()); assert!(!snapshot.has_nodes()); assert_eq!(snapshot.current_idx.load(Ordering::SeqCst), 0); - assert!(snapshot.next().is_none()); + assert!(snapshot.next_node().is_none()); } #[test] @@ -121,13 +147,13 @@ mod tests { // Assert assert!(!is_updated); assert!(snapshot.existing_nodes.is_empty()); - assert!(snapshot.next().is_none()); + assert!(snapshot.next_node().is_none()); // Act 2 let is_updated = snapshot.update_node(&node, unhealthy); // Assert assert!(!is_updated); assert!(snapshot.existing_nodes.is_empty()); - assert!(snapshot.next().is_none()); + assert!(snapshot.next_node().is_none()); } #[test] @@ -142,7 +168,7 @@ mod tests { let is_updated = snapshot.update_node(&node, health); assert!(is_updated); assert!(snapshot.has_nodes()); - assert_eq!(snapshot.next().unwrap(), node); + assert_eq!(snapshot.next_node().unwrap(), node); assert_eq!(snapshot.current_idx.load(Ordering::SeqCst), 1); } @@ -158,7 +184,7 @@ mod tests { let is_updated = snapshot.update_node(&node, unhealthy); assert!(is_updated); assert!(!snapshot.has_nodes()); - assert!(snapshot.next().is_none()); + assert!(snapshot.next_node().is_none()); } #[test] @@ -217,4 +243,84 @@ mod tests { assert!(!nodes_changed); assert!(snapshot.existing_nodes.is_empty()); } + + #[test] + fn test_next_node() { + // Arrange + let mut snapshot = RoundRobinRoutingSnapshot::new(); + let node_1 = Node::new("api1.com").unwrap(); + let node_2 = Node::new("api2.com").unwrap(); + let node_3 = Node::new("api3.com").unwrap(); + let nodes = vec![node_1, node_2, node_3]; + snapshot.existing_nodes.extend(nodes.clone()); + snapshot.healthy_nodes.extend(nodes.clone()); + // Act + let n = 6; + let mut count_map = HashMap::new(); + for _ in 0..n { + let node = snapshot.next_node().unwrap(); + count_map.entry(node).and_modify(|v| *v += 1).or_insert(1); + } + // Assert each node was returned 2 times + let k = 2; + assert_eq!( + count_map.len(), + nodes.len(), + "The number of unique elements is not {}", + nodes.len() + ); + for (item, &count) in &count_map { + assert_eq!( + count, k, + "Element {:?} does not appear exactly {} times", + item, k + ); + } + } + + #[test] + fn test_n_nodes() { + // Arrange + let mut snapshot = RoundRobinRoutingSnapshot::new(); + let node_1 = Node::new("api1.com").unwrap(); + let node_2 = Node::new("api2.com").unwrap(); + let node_3 = Node::new("api3.com").unwrap(); + let node_4 = Node::new("api4.com").unwrap(); + let node_5 = Node::new("api5.com").unwrap(); + let nodes = vec![ + node_1.clone(), + node_2.clone(), + node_3.clone(), + node_4.clone(), + node_5.clone(), + ]; + snapshot.healthy_nodes.extend(nodes.clone()); + // First call + let mut n_nodes: Vec<_> = snapshot.next_n_nodes(3).expect("failed to get nodes"); + // Second call + n_nodes.extend(snapshot.next_n_nodes(3).expect("failed to get nodes")); + // Third call + n_nodes.extend(snapshot.next_n_nodes(4).expect("failed to get nodes")); + // Fourth call + n_nodes.extend(snapshot.next_n_nodes(5).expect("failed to get nodes")); + // Assert each node was returned 3 times + let k = 3; + let mut count_map = HashMap::new(); + for item in n_nodes.iter() { + count_map.entry(item).and_modify(|v| *v += 1).or_insert(1); + } + assert_eq!( + count_map.len(), + nodes.len(), + "The number of unique elements is not {}", + nodes.len() + ); + for (item, &count) in &count_map { + assert_eq!( + count, k, + "Element {:?} does not appear exactly {} times", + item, k + ); + } + } } diff --git a/ic-agent/src/agent/http_transport/dynamic_routing/snapshot/routing_snapshot.rs b/ic-agent/src/agent/http_transport/dynamic_routing/snapshot/routing_snapshot.rs index 155b8eac..242abdfe 100644 --- a/ic-agent/src/agent/http_transport/dynamic_routing/snapshot/routing_snapshot.rs +++ b/ic-agent/src/agent/http_transport/dynamic_routing/snapshot/routing_snapshot.rs @@ -6,8 +6,10 @@ use crate::agent::http_transport::dynamic_routing::{health_check::HealthCheckSta pub trait RoutingSnapshot: Send + Sync + Clone + Debug { /// Returns `true` if the snapshot has nodes. fn has_nodes(&self) -> bool; - /// Get the next node in the snapshot. - fn next(&self) -> Option; + /// Get next node from the snapshot. + fn next_node(&self) -> Option; + /// Get up to n different nodes from the snapshot. + fn next_n_nodes(&self, n: usize) -> Option>; /// Syncs the nodes in the snapshot with the provided list of nodes, returning `true` if the snapshot was updated. fn sync_nodes(&mut self, nodes: &[Node]) -> bool; /// Updates the health status of a specific node, returning `true` if the node was found and updated. diff --git a/ic-agent/src/agent/http_transport/route_provider.rs b/ic-agent/src/agent/http_transport/route_provider.rs index 608b2888..d8e7a8e3 100644 --- a/ic-agent/src/agent/http_transport/route_provider.rs +++ b/ic-agent/src/agent/http_transport/route_provider.rs @@ -15,8 +15,19 @@ use crate::agent::{ /// A [`RouteProvider`] for dynamic generation of routing urls. pub trait RouteProvider: std::fmt::Debug + Send + Sync { - /// Generate next routing url + /// Generates the next routing URL based on the internal routing logic. + /// + /// This method returns a single `Url` that can be used for routing. + /// The logic behind determining the next URL can vary depending on the implementation fn route(&self) -> Result; + + /// Generates up to `n` different routing URLs in order of priority. + /// + /// This method returns a vector of `Url` instances, each representing a routing + /// endpoint. The URLs are ordered by priority, with the most preferred route + /// appearing first. The returned vector can contain fewer than `n` URLs if + /// fewer are available. + fn n_ordered_routes(&self, n: usize) -> Result, AgentError>; } /// A simple implementation of the [`RouteProvider`] which produces an even distribution of the urls from the input ones. @@ -38,6 +49,28 @@ impl RouteProvider for RoundRobinRouteProvider { let prev_idx = self.current_idx.fetch_add(1, Ordering::Relaxed); Ok(self.routes[prev_idx % self.routes.len()].clone()) } + + fn n_ordered_routes(&self, n: usize) -> Result, AgentError> { + if n == 0 { + return Ok(Vec::new()); + } + + if n >= self.routes.len() { + return Ok(self.routes.clone()); + } + + let idx = self.current_idx.fetch_add(n, Ordering::Relaxed) % self.routes.len(); + let mut urls = Vec::with_capacity(n); + + if self.routes.len() - idx >= n { + urls.extend_from_slice(&self.routes[idx..idx + n]); + } else { + urls.extend_from_slice(&self.routes[idx..]); + urls.extend_from_slice(&self.routes[..n - urls.len()]); + } + + Ok(urls) + } } impl RoundRobinRouteProvider { @@ -99,4 +132,56 @@ mod tests { .collect(); assert_eq!(expected_urls, urls); } + + #[test] + fn test_n_routes() { + // Test with an empty list of urls + let provider = RoundRobinRouteProvider::new(Vec::<&str>::new()) + .expect("failed to create a route provider"); + let urls_iter = provider.n_ordered_routes(1).expect("failed to get urls"); + assert!(urls_iter.is_empty()); + // Test with non-empty list of urls + let provider = RoundRobinRouteProvider::new(vec![ + "https://url1.com", + "https://url2.com", + "https://url3.com", + "https://url4.com", + "https://url5.com", + ]) + .expect("failed to create a route provider"); + // First call + let urls: Vec<_> = provider.n_ordered_routes(3).expect("failed to get urls"); + let expected_urls: Vec = ["https://url1.com", "https://url2.com", "https://url3.com"] + .iter() + .map(|url_str| Url::parse(url_str).expect("invalid URL")) + .collect(); + assert_eq!(urls, expected_urls); + // Second call + let urls: Vec<_> = provider.n_ordered_routes(3).expect("failed to get urls"); + let expected_urls: Vec = ["https://url4.com", "https://url5.com", "https://url1.com"] + .iter() + .map(|url_str| Url::parse(url_str).expect("invalid URL")) + .collect(); + assert_eq!(urls, expected_urls); + // Third call + let urls: Vec<_> = provider.n_ordered_routes(2).expect("failed to get urls"); + let expected_urls: Vec = ["https://url2.com", "https://url3.com"] + .iter() + .map(|url_str| Url::parse(url_str).expect("invalid URL")) + .collect(); + assert_eq!(urls, expected_urls); + // Fourth call + let urls: Vec<_> = provider.n_ordered_routes(5).expect("failed to get urls"); + let expected_urls: Vec = [ + "https://url1.com", + "https://url2.com", + "https://url3.com", + "https://url4.com", + "https://url5.com", + ] + .iter() + .map(|url_str| Url::parse(url_str).expect("invalid URL")) + .collect(); + assert_eq!(urls, expected_urls); + } }