From 4c28c6c6edf11b813059644428b394df20ce8413 Mon Sep 17 00:00:00 2001 From: Jacob Finkelman Date: Fri, 1 Dec 2023 11:31:24 -0500 Subject: [PATCH] perf: more efficient intersection (#157) * refactor: more efficient intersection * more comments --- src/range.rs | 78 +++++++++++++++++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 25 deletions(-) diff --git a/src/range.rs b/src/range.rs index 91933e61..e3ba82e2 100644 --- a/src/range.rs +++ b/src/range.rs @@ -350,41 +350,69 @@ impl Range { /// Computes the intersection of two sets of versions. pub fn intersection(&self, other: &Self) -> Self { - let mut segments: SmallVec> = SmallVec::empty(); + let mut output: SmallVec> = SmallVec::empty(); let mut left_iter = self.segments.iter().peekable(); let mut right_iter = other.segments.iter().peekable(); - - while let (Some((left_start, left_end)), Some((right_start, right_end))) = - (left_iter.peek(), right_iter.peek()) + // By the definition of intersection any point that is matched by the output + // must have a segment in each of the inputs that it matches. + // Therefore, every segment in the output must be the intersection of a segment from each of the inputs. + // It would be correct to do the "O(n^2)" thing, by computing the intersection of every segment from one input + // with every segment of the other input, and sorting the result. + // We can avoid the sorting by generating our candidate segments with an increasing `end` value. + while let Some(((left_start, left_end), (right_start, right_end))) = + left_iter.peek().zip(right_iter.peek()) { + // The next smallest `end` value is going to come from one of the inputs. + let left_end_is_smaller = match (left_end, right_end) { + (Included(l), Included(r)) + | (Excluded(l), Excluded(r)) + | (Excluded(l), Included(r)) => l <= r, + + (Included(l), Excluded(r)) => l < r, + (_, Unbounded) => true, + (Unbounded, _) => false, + }; + // Now that we are processing `end` we will never have to process any segment smaller than that. + // We can ensure that the input that `end` came from is larger than `end` by advancing it one step. + // `end` is the smaller available input, so we know the other input is already larger than `end`. + // Note: We can call `other_iter.next_if( == end)`, but the ends lining up is rare enough that + // it does not end up being faster in practice. + let (other_start, end) = if left_end_is_smaller { + left_iter.next(); + (right_start, left_end) + } else { + right_iter.next(); + (left_start, right_end) + }; + // `start` will either come from the input `end` came from or the other input, whichever one is larger. + // The intersection is invalid if `start` > `end`. + // But, we already know that the segments in our input are valid. + // So we do not need to check if the `start` from the input `end` came from is smaller then `end`. + // If the `other_start` is larger than end, then the intersection will be invalid. + if !valid_segment(other_start, end) { + // Note: We can call `this_iter.next_if(!valid_segment(other_start, this_end))` in a loop. + // But the checks make it slower for the benchmarked inputs. + continue; + } let start = match (left_start, right_start) { (Included(l), Included(r)) => Included(std::cmp::max(l, r)), (Excluded(l), Excluded(r)) => Excluded(std::cmp::max(l, r)), - (Included(i), Excluded(e)) | (Excluded(e), Included(i)) if i <= e => Excluded(e), - (Included(i), Excluded(e)) | (Excluded(e), Included(i)) if e < i => Included(i), - (s, Unbounded) | (Unbounded, s) => s.as_ref(), - _ => unreachable!(), - } - .cloned(); - let end = match (left_end, right_end) { - (Included(l), Included(r)) => Included(std::cmp::min(l, r)), - (Excluded(l), Excluded(r)) => Excluded(std::cmp::min(l, r)), - - (Included(i), Excluded(e)) | (Excluded(e), Included(i)) if i >= e => Excluded(e), - (Included(i), Excluded(e)) | (Excluded(e), Included(i)) if e > i => Included(i), + (Included(i), Excluded(e)) | (Excluded(e), Included(i)) => { + if i <= e { + Excluded(e) + } else { + Included(i) + } + } (s, Unbounded) | (Unbounded, s) => s.as_ref(), - _ => unreachable!(), - } - .cloned(); - left_iter.next_if(|(_, e)| e == &end); - right_iter.next_if(|(_, e)| e == &end); - if valid_segment(&start, &end) { - segments.push((start, end)) - } + }; + // Now we clone and push a new segment. + // By dealing with references until now we ensure that NO cloning happens when we reject the segment. + output.push((start.cloned(), end.clone())) } - Self { segments }.check_invariants() + Self { segments: output }.check_invariants() } /// Returns a simpler Range that contains the same versions