Skip to content

Commit

Permalink
[search] Increased PreRankerResult batch size to 1000.
Browse files Browse the repository at this point in the history
Signed-off-by: Viktor Govako <[email protected]>
  • Loading branch information
vng committed Aug 27, 2023
1 parent 378b461 commit 5224d3d
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 44 deletions.
12 changes: 4 additions & 8 deletions search/intermediate_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ bool PreRankerResult::LessRankAndPopularity(PreRankerResult const & lhs, PreRank
return lhs.m_info.m_rank > rhs.m_info.m_rank;
if (lhs.m_info.m_popularity != rhs.m_info.m_popularity)
return lhs.m_info.m_popularity > rhs.m_info.m_popularity;

/// @todo Remove this epilogue once we have _enough_ ranks and popularities in the data.
return lhs.m_info.m_distanceToPivot < rhs.m_info.m_distanceToPivot;
}

// static
// Comparator over PreRankerResult: returns true when |lhs| should be ordered before |rhs|.
// The primary key is m_info.m_distanceToPivot, smaller value first.
// NOTE(review): this listing appears to interleave removed and added diff lines, so the body
// below holds two variants at once — confirm which one is the intended final code.
bool PreRankerResult::LessDistance(PreRankerResult const & lhs, PreRankerResult const & rhs)
{
if (lhs.m_info.m_distanceToPivot != rhs.m_info.m_distanceToPivot)
return lhs.m_info.m_distanceToPivot < rhs.m_info.m_distanceToPivot;
// Equal distances: the result with the greater m_rank goes first.
return lhs.m_info.m_rank > rhs.m_info.m_rank;
// NOTE(review): unreachable as listed, because the return above is unconditional.
return lhs.m_info.m_distanceToPivot < rhs.m_info.m_distanceToPivot;
}

// static
Expand Down Expand Up @@ -121,11 +121,7 @@ bool PreRankerResult::LessByExactMatch(PreRankerResult const & lhs, PreRankerRes
if (lScore != rScore)
return lScore;

auto const byTokens = CompareByTokensMatch(lhs, rhs);
if (byTokens != 0)
return byTokens == -1;

return LessDistance(lhs, rhs);
return CompareByTokensMatch(lhs, rhs) == -1;
}

bool PreRankerResult::CategoriesComparator::operator()(PreRankerResult const & lhs,
Expand Down
41 changes: 30 additions & 11 deletions search/pre_ranker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <algorithm>
#include <iterator>
#include <random>

namespace search
{
Expand Down Expand Up @@ -55,7 +56,8 @@ void SweepNearbyResults(m2::PointD const & eps, unordered_set<FeatureID> const &
} // namespace

PreRanker::PreRanker(DataSource const & dataSource, Ranker & ranker)
: m_dataSource(dataSource), m_ranker(ranker), m_pivotFeatures(dataSource)
: m_dataSource(dataSource), m_ranker(ranker), m_pivotFeatures(dataSource)
, m_rndSeed(std::random_device()())
{
}

Expand Down Expand Up @@ -157,7 +159,14 @@ template <class CompT, class ContT> class CompareIndices
};
} // namespace

void PreRanker::Filter(bool viewportSearch)
void PreRanker::DbgFindAndLog(std::set<uint32_t> const & ids) const
{
for (auto const & r : m_results)
if (ids.count(r.GetId().m_index) > 0)
LOG(LDEBUG, (r));
}

void PreRanker::Filter()
{
auto const lessForUnique = [](PreRankerResult const & lhs, PreRankerResult const & rhs)
{
Expand All @@ -169,30 +178,40 @@ void PreRanker::Filter(bool viewportSearch)
return PreRankerResult::CompareByTokensMatch(lhs, rhs) == -1;
};

/// @DebugNote
/// Use DbgFindAndLog to check needed ids before and after filtering

base::SortUnique(m_results, lessForUnique, base::EqualsBy(&PreRankerResult::GetId));

if (viewportSearch)
if (m_params.m_viewportSearch)
FilterForViewportSearch();

// Viewport search ends here.
if (m_results.size() <= BatchSize())
return;

vector<size_t> indices(m_results.size());
generate(indices.begin(), indices.end(), [n = 0] () mutable { return n++; });
unordered_set<size_t> filtered;
std::vector<size_t> indices(m_results.size());
std::generate(indices.begin(), indices.end(), [n = 0] () mutable { return n++; });
std::unordered_set<size_t> filtered;

auto const iBeg = indices.begin();
auto const iMiddle = iBeg + BatchSize();
auto const iEnd = indices.end();

nth_element(iBeg, iMiddle, iEnd, CompareIndices(&PreRankerResult::LessDistance, m_results));
std::nth_element(iBeg, iMiddle, iEnd, CompareIndices(&PreRankerResult::LessDistance, m_results));
filtered.insert(iBeg, iMiddle);

if (!m_params.m_categorialRequest)
{
nth_element(iBeg, iMiddle, iEnd, CompareIndices(&PreRankerResult::LessRankAndPopularity, m_results));
std::nth_element(iBeg, iMiddle, iEnd, CompareIndices(&PreRankerResult::LessRankAndPopularity, m_results));
filtered.insert(iBeg, iMiddle);
nth_element(iBeg, iMiddle, iEnd, CompareIndices(&PreRankerResult::LessByExactMatch, m_results));

// Shuffle to give a chance to far results, not only closest ones (see above).
// Search is not stable in rare cases, but we avoid increasing m_everywhereBatchSize.
/// @todo Move up once we have _enough_ ranks and popularities.
std::shuffle(iBeg, iEnd, std::mt19937(m_rndSeed));

std::nth_element(iBeg, iMiddle, iEnd, CompareIndices(&PreRankerResult::LessByExactMatch, m_results));
filtered.insert(iBeg, iMiddle);
}
else
Expand All @@ -206,7 +225,7 @@ void PreRanker::Filter(bool viewportSearch)
2 * kPedestrianRadiusMeters;
comparator.m_viewport = m_params.m_viewport;

nth_element(iBeg, iMiddle, iEnd, CompareIndices(comparator, m_results));
std::nth_element(iBeg, iMiddle, iEnd, CompareIndices(comparator, m_results));
filtered.insert(iBeg, iMiddle);
}

Expand All @@ -221,7 +240,7 @@ void PreRanker::UpdateResults(bool lastUpdate)
{
FilterRelaxedResults(lastUpdate);
FillMissingFieldsInPreResults();
Filter(m_params.m_viewportSearch);
Filter();
m_numSentResults += m_results.size();
m_ranker.AddPreRankerResults(std::move(m_results));
m_results.clear();
Expand Down
21 changes: 12 additions & 9 deletions search/pre_ranker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ class PreRanker

int m_scale = 0;

// Batch size for Everywhere search mode. For viewport search we limit search results number
// with SweepNearbyResults.
size_t m_everywhereBatchSize = 100;
// Batch size for Everywhere search mode.
// For viewport search we limit search results number with SweepNearbyResults.
// Increased to 1K, no problem to read 1-2K Features per search now, but the quality is much better.
/// @see BA_SanMartin test.
size_t m_everywhereBatchSize = 1000;

// The maximum total number of results to be emitted in all batches.
size_t m_limit = 0;
Expand Down Expand Up @@ -76,11 +78,6 @@ class PreRanker
m_haveFullyMatchedResult = true;
}

// Computes missing fields for all pre-results.
void FillMissingFieldsInPreResults();

void Filter(bool viewportSearch);

// Emit a new batch of results up the pipeline (i.e. to ranker).
// Use |lastUpdate| to indicate that no more results will be added.
void UpdateResults(bool lastUpdate);
Expand Down Expand Up @@ -137,8 +134,12 @@ class PreRanker
void ClearCaches();

private:
void FilterForViewportSearch();
// Computes missing fields for all pre-results.
void FillMissingFieldsInPreResults();
void DbgFindAndLog(std::set<uint32_t> const & ids) const;

void FilterForViewportSearch();
void Filter();
void FilterRelaxedResults(bool lastUpdate);

DataSource const & m_dataSource;
Expand Down Expand Up @@ -166,6 +167,8 @@ class PreRanker
std::unordered_set<FeatureID> m_prevEmit;
/// @}

unsigned m_rndSeed;

DISALLOW_COPY_AND_MOVE(PreRanker);
};
} // namespace search
13 changes: 10 additions & 3 deletions search/ranker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -748,9 +748,14 @@ void Ranker::UpdateResults(bool lastUpdate)
}
else
{
/// @note Here is _reverse_ order sorting, because bigger is better.
sort(m_tentativeResults.rbegin(), m_tentativeResults.rend(),
base::LessBy(&RankerResult::GetLinearModelRank));
// Can get same Town features (from World) when searching in many MWMs.
base::SortUnique(m_tentativeResults,
[](RankerResult const & r1, RankerResult const & r2)
{
// Expect that linear rank is equal for the same features.
return r1.GetLinearModelRank() > r2.GetLinearModelRank();
},
base::EqualsBy(&RankerResult::GetID));

ProcessSuggestions(m_tentativeResults);
}
Expand Down Expand Up @@ -824,6 +829,8 @@ void Ranker::LoadCountriesTree() { m_regionInfoGetter.LoadCountriesTree(); }

void Ranker::MakeRankerResults()
{
LOG(LDEBUG, ("PreRankerResults number =", m_preRankerResults.size()));

bool const isViewportMode = m_geocoderParams.m_mode == Mode::Viewport;

RankerResultMaker maker(*this, m_dataSource, m_infoGetter, m_reverseGeocoder, m_geocoderParams);
Expand Down
14 changes: 4 additions & 10 deletions search/search_integration_tests/pre_ranker_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,6 @@ UNIT_CLASS_TEST(PreRankerTest, Smoke)
m2::PointD const kPivot(0, 0);
m2::RectD const kViewport(-5, -5, 5, 5);

/// @todo Well, I'm not sure that 50 results will have unique distances to pivot.
/// 7x7 grid is 49, so potentially it can be 51 (north and south) or (east and west).
/// But we should consider circle (ellipse) around pivot and I can't say,
/// how it goes in meters radius on integer mercator grid.
size_t constexpr kBatchSize = 50;

vector<TestPOI> pois;
for (int x = -5; x <= 5; ++x)
{
Expand All @@ -111,7 +105,7 @@ UNIT_CLASS_TEST(PreRankerTest, Smoke)
}
}

TEST_LESS(kBatchSize, pois.size(), ());
size_t const batchSize = pois.size() / 2;

auto mwmId = BuildCountry("Cafeland", [&](TestMwmBuilder & builder)
{
Expand All @@ -133,7 +127,7 @@ UNIT_CLASS_TEST(PreRankerTest, Smoke)
params.m_viewport = kViewport;
params.m_accuratePivotCenter = kPivot;
params.m_scale = scales::GetUpperScale();
params.m_everywhereBatchSize = kBatchSize;
params.m_everywhereBatchSize = batchSize;
params.m_limit = pois.size();
params.m_viewportSearch = false;
preRanker.Init(params);
Expand All @@ -159,8 +153,8 @@ UNIT_CLASS_TEST(PreRankerTest, Smoke)
TEST(ranker.Finished(), ());

size_t const count = results.size();
// See todo comment above for details.
TEST(count == kBatchSize || count == kBatchSize + 1, (count));
// Depends on std::shuffle, but let's keep a 6% threshold.
TEST(count > batchSize*1.06 && count < batchSize*1.94, (count));

vector<bool> checked(pois.size());
for (size_t i = 0; i < count; ++i)
Expand Down
34 changes: 31 additions & 3 deletions search/search_quality/search_quality_tests/real_mwm_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ class MwmTestsFixture : public search::tests_support::SearchTestBase
}
}

// Returns how many entries of |results| satisfy
// EqualClassifType(r.GetFeatureType(), type).
static size_t CountClassifType(Range const & results, uint32_t type)
{
  size_t matched = 0;
  for (search::Result const & r : results)
  {
    if (EqualClassifType(r.GetFeatureType(), type))
      ++matched;
  }
  return matched;
}

static void NameStartsWith(Range const & results, base::StringIL const & prefixes)
{
for (auto const & r : results)
Expand Down Expand Up @@ -335,7 +343,10 @@ UNIT_CLASS_TEST(MwmTestsFixture, Barcelona_Carrers)
auto const & results = request->Results();
TEST_GREATER(results.size(), kTopPoiResultsCount, ());

Range const range(results, 0, 4);
// - First 2 streets in Barcelona (1-2 km)
// - Next streets in Badalona, Barbera, Sabadell, ... (20-30 km)
// - Again 2 _minor_ footways in Barcelona (1-2 km)
Range const range(results, 0, 2);
EqualClassifType(range, GetClassifTypes({{"highway"}}));
CenterInRect(range, {2.1651583, 41.3899995, 2.1863021, 41.4060494});
}
Expand Down Expand Up @@ -473,6 +484,7 @@ UNIT_CLASS_TEST(MwmTestsFixture, French_StopWord_Category)
UNIT_CLASS_TEST(MwmTestsFixture, Street_BusStop)
{
// Buenos Aires
// Also should download Argentina_Santa Fe.
ms::LatLon const center(-34.60655, -58.43566);
SetViewportAndLoadMaps(center);

Expand All @@ -481,9 +493,10 @@ UNIT_CLASS_TEST(MwmTestsFixture, Street_BusStop)
auto const & results = request->Results();
TEST_GREATER(results.size(), kTopPoiResultsCount, ());

// Top results are Hotel and Street (sometimes bus stop).
// Top results are Hotel, Shop and Street.
// Full Match street (20 km) is better than Full Prefix bus stop (1 km).
Range const range(results);
EqualClassifType(range, GetClassifTypes({{"tourism", "hotel"}, {"highway", "bus_stop"}, {"highway", "residential"}}));
EqualClassifType(range, GetClassifTypes({{"tourism", "hotel"}, {"shop"}, {"highway", "residential"}}));
}

{
Expand Down Expand Up @@ -824,4 +837,19 @@ UNIT_CLASS_TEST(MwmTestsFixture, BA_LasHeras)
}
}

// Search-quality check for the query "San Martin" around the given center:
// asserts that more than kResultsCount results are returned and that more than 2
// of the results in Range(results, 0, kResultsCount) have the {"railway", "station"} type.
UNIT_CLASS_TEST(MwmTestsFixture, BA_SanMartin)
{
// Buenos Aires (Palermo)
ms::LatLon const center(-34.5801392, -58.415764);
SetViewportAndLoadMaps(center);

{
auto request = MakeRequest("San Martin");
auto const & results = request->Results();
// Size of the inspected range; presumably Range(results, 0, N) covers the
// first N results — confirm against the Range helper.
size_t constexpr kResultsCount = 12;
// The result list must be strictly longer than the inspected range.
TEST_GREATER(results.size(), kResultsCount, ());
// Strictly more than 2 entries of the inspected range must be railway stations.
TEST_GREATER(CountClassifType(Range(results, 0, kResultsCount),
classif().GetTypeByPath({"railway", "station"})), 2, ());
}
}
} // namespace real_mwm_tests

0 comments on commit 5224d3d

Please sign in to comment.