diff --git a/src/search/merge_and_shrink/merge_scoring_function.cc b/src/search/merge_and_shrink/merge_scoring_function.cc index 83dd2dfd18..4e77d590b2 100644 --- a/src/search/merge_and_shrink/merge_scoring_function.cc +++ b/src/search/merge_and_shrink/merge_scoring_function.cc @@ -8,8 +8,9 @@ using namespace std; namespace merge_and_shrink { -MergeScoringFunction::MergeScoringFunction() - : initialized(false) { +MergeScoringFunction::MergeScoringFunction(const plugins::Options &options) + : use_caching(options.get("use_caching")), + initialized(false) { } void MergeScoringFunction::dump_options(utils::LogProxy &log) const { @@ -20,6 +21,14 @@ void MergeScoringFunction::dump_options(utils::LogProxy &log) const { } } +void add_merge_scoring_function_options_to_feature(plugins::Feature &feature) { + feature.add_option( + "use_caching", + "Cache scores for merge candidates. Currently only supported by the " + "MIASM scoring function.", + "false"); +} + static class MergeScoringFunctionCategoryPlugin : public plugins::TypedCategoryPlugin { public: MergeScoringFunctionCategoryPlugin() : TypedCategoryPlugin("MergeScoringFunction") { diff --git a/src/search/merge_and_shrink/merge_scoring_function.h b/src/search/merge_and_shrink/merge_scoring_function.h index 7a2e8c4382..3bea5e2934 100644 --- a/src/search/merge_and_shrink/merge_scoring_function.h +++ b/src/search/merge_and_shrink/merge_scoring_function.h @@ -1,28 +1,47 @@ #ifndef MERGE_AND_SHRINK_MERGE_SCORING_FUNCTION_H #define MERGE_AND_SHRINK_MERGE_SCORING_FUNCTION_H +#include #include +#include #include class TaskProxy; +namespace plugins { +class Feature; +class Options; +} + namespace utils { class LogProxy; } namespace merge_and_shrink { class FactoredTransitionSystem; +struct MergeCandidate { + int id; + int index1; + int index2; + MergeCandidate(int id, int index1, int index2) + : id(id), index1(index1), index2(index2) { + } +}; + class MergeScoringFunction { protected: + const bool use_caching; bool initialized; + std::unordered_map cached_scores; + virtual std::string name() const = 0; virtual void dump_function_specific_options(utils::LogProxy &) const {} public: - MergeScoringFunction(); + explicit MergeScoringFunction(const plugins::Options &options); virtual ~MergeScoringFunction() = default; virtual std::vector compute_scores( const FactoredTransitionSystem &fts, - const std::vector> &merge_candidates) = 0; + const std::vector> &merge_candidates) = 0; virtual bool requires_init_distances() const = 0; virtual bool requires_goal_distances() const = 0; @@ -33,6 +52,8 @@ class MergeScoringFunction { void dump_options(utils::LogProxy &log) const; }; + +extern void add_merge_scoring_function_options_to_feature(plugins::Feature &feature); } #endif diff --git a/src/search/merge_and_shrink/merge_scoring_function_dfp.cc b/src/search/merge_and_shrink/merge_scoring_function_dfp.cc index fa259bf93a..dce5b803c3 100644 --- a/src/search/merge_and_shrink/merge_scoring_function_dfp.cc +++ b/src/search/merge_and_shrink/merge_scoring_function_dfp.cc @@ -13,6 +13,11 @@ using namespace std; namespace merge_and_shrink { +MergeScoringFunctionDFP::MergeScoringFunctionDFP( + const plugins::Options &options) + : MergeScoringFunction(options) { +} + vector MergeScoringFunctionDFP::compute_label_ranks( const FactoredTransitionSystem &fts, int index) const { const TransitionSystem &ts = fts.get_transition_system(index); @@ -60,7 +65,7 @@ vector MergeScoringFunctionDFP::compute_label_ranks( vector MergeScoringFunctionDFP::compute_scores( const FactoredTransitionSystem &fts, - const vector> &merge_candidates) { + const vector> &merge_candidates) { int num_ts = fts.get_size(); vector> transition_system_label_ranks(num_ts); @@ -68,9 +73,9 @@ vector MergeScoringFunctionDFP::compute_scores( scores.reserve(merge_candidates.size()); // Go over all pairs of transition systems and compute their weight. - for (pair merge_candidate : merge_candidates) { - int ts_index1 = merge_candidate.first; - int ts_index2 = merge_candidate.second; + for (const auto &merge_candidate : merge_candidates) { + int ts_index1 = merge_candidate->index1; + int ts_index2 = merge_candidate->index2; vector &label_ranks1 = transition_system_label_ranks[ts_index1]; if (label_ranks1.empty()) { @@ -129,10 +134,7 @@ class MergeScoringFunctionDFPFeature : public plugins::TypedFeature create_component(const plugins::Options &, const utils::Context &) const override { - return make_shared(); + add_merge_scoring_function_options_to_feature(*this); } }; diff --git a/src/search/merge_and_shrink/merge_scoring_function_dfp.h b/src/search/merge_and_shrink/merge_scoring_function_dfp.h index a5a1b9de18..bf68ad8e7b 100644 --- a/src/search/merge_and_shrink/merge_scoring_function_dfp.h +++ b/src/search/merge_and_shrink/merge_scoring_function_dfp.h @@ -10,11 +10,11 @@ class MergeScoringFunctionDFP : public MergeScoringFunction { protected: virtual std::string name() const override; public: - MergeScoringFunctionDFP() = default; + explicit MergeScoringFunctionDFP(const plugins::Options &options); virtual ~MergeScoringFunctionDFP() override = default; virtual std::vector compute_scores( const FactoredTransitionSystem &fts, - const std::vector> &merge_candidates) override; + const std::vector> &merge_candidates) override; virtual bool requires_init_distances() const override { return false; diff --git a/src/search/merge_and_shrink/merge_scoring_function_goal_relevance.cc b/src/search/merge_and_shrink/merge_scoring_function_goal_relevance.cc index f92fdc5535..7eacde822b 100644 --- a/src/search/merge_and_shrink/merge_scoring_function_goal_relevance.cc +++ b/src/search/merge_and_shrink/merge_scoring_function_goal_relevance.cc @@ -9,9 +9,14 @@ using namespace std; namespace merge_and_shrink { +MergeScoringFunctionGoalRelevance::MergeScoringFunctionGoalRelevance( + const plugins::Options &options) + : MergeScoringFunction(options) { +} + vector MergeScoringFunctionGoalRelevance::compute_scores( const FactoredTransitionSystem &fts, - const vector> &merge_candidates) { + const vector> &merge_candidates) { int num_ts = fts.get_size(); vector goal_relevant(num_ts, false); for (int ts_index : fts) { @@ -23,9 +28,9 @@ vector MergeScoringFunctionGoalRelevance::compute_scores( vector scores; scores.reserve(merge_candidates.size()); - for (pair merge_candidate : merge_candidates) { - int ts_index1 = merge_candidate.first; - int ts_index2 = merge_candidate.second; + for (const auto &merge_candidate : merge_candidates) { + int ts_index1 = merge_candidate->index1; + int ts_index2 = merge_candidate->index2; int score = INF; if (goal_relevant[ts_index1] || goal_relevant[ts_index2]) { score = 0; @@ -48,10 +53,7 @@ class MergeScoringFunctionGoalRelevanceFeature : public plugins::TypedFeature create_component(const plugins::Options &, const utils::Context &) const override { - return make_shared(); + add_merge_scoring_function_options_to_feature(*this); } }; diff --git a/src/search/merge_and_shrink/merge_scoring_function_goal_relevance.h b/src/search/merge_and_shrink/merge_scoring_function_goal_relevance.h index 64d00f362d..dd1b3d2d4c 100644 --- a/src/search/merge_and_shrink/merge_scoring_function_goal_relevance.h +++ b/src/search/merge_and_shrink/merge_scoring_function_goal_relevance.h @@ -8,11 +8,11 @@ class MergeScoringFunctionGoalRelevance : public MergeScoringFunction { protected: virtual std::string name() const override; public: - MergeScoringFunctionGoalRelevance() = default; + MergeScoringFunctionGoalRelevance(const plugins::Options &options); virtual ~MergeScoringFunctionGoalRelevance() override = default; virtual std::vector compute_scores( const FactoredTransitionSystem &fts, - const std::vector> &merge_candidates) override; + const std::vector> &merge_candidates) override; virtual bool requires_init_distances() const override { return false; diff --git a/src/search/merge_and_shrink/merge_scoring_function_miasm.cc b/src/search/merge_and_shrink/merge_scoring_function_miasm.cc index f1f10dc113..57cfdd40f1 100644 --- a/src/search/merge_and_shrink/merge_scoring_function_miasm.cc +++ b/src/search/merge_and_shrink/merge_scoring_function_miasm.cc @@ -16,7 +16,8 @@ using namespace std; namespace merge_and_shrink { MergeScoringFunctionMIASM::MergeScoringFunctionMIASM( const plugins::Options &options) - : shrink_strategy(options.get>("shrink_strategy")), + : MergeScoringFunction(options), + shrink_strategy(options.get>("shrink_strategy")), max_states(options.get("max_states")), max_states_before_merge(options.get("max_states_before_merge")), shrink_threshold_before_merge(options.get("threshold_before_merge")), @@ -25,43 +26,52 @@ MergeScoringFunctionMIASM::MergeScoringFunctionMIASM( vector MergeScoringFunctionMIASM::compute_scores( const FactoredTransitionSystem &fts, - const vector> &merge_candidates) { + const vector> &merge_candidates) { vector scores; scores.reserve(merge_candidates.size()); - for (pair merge_candidate : merge_candidates) { - int index1 = merge_candidate.first; - int index2 = merge_candidate.second; - unique_ptr product = shrink_before_merge_externally( - fts, - index1, - index2, - *shrink_strategy, - max_states, - max_states_before_merge, - shrink_threshold_before_merge, - silent_log); + for (const auto &merge_candidate : merge_candidates) { + double score; + int id = merge_candidate->id; + if (use_caching && cached_scores.count(id)) { + score = cached_scores[id]; + } else { + int index1 = merge_candidate->index1; + int index2 = merge_candidate->index2; + unique_ptr product = shrink_before_merge_externally( + fts, + index1, + index2, + *shrink_strategy, + max_states, + max_states_before_merge, + shrink_threshold_before_merge, + silent_log); - // Compute distances for the product and count the alive states. - unique_ptr distances = utils::make_unique_ptr(*product); - const bool compute_init_distances = true; - const bool compute_goal_distances = true; - distances->compute_distances(compute_init_distances, compute_goal_distances, silent_log); - int num_states = product->get_size(); - int alive_states_count = 0; - for (int state = 0; state < num_states; ++state) { - if (distances->get_init_distance(state) != INF && - distances->get_goal_distance(state) != INF) { - ++alive_states_count; + // Compute distances for the product and count the alive states. + unique_ptr distances = utils::make_unique_ptr(*product); + const bool compute_init_distances = true; + const bool compute_goal_distances = true; + distances->compute_distances(compute_init_distances, compute_goal_distances, silent_log); + int num_states = product->get_size(); + int alive_states_count = 0; + for (int state = 0; state < num_states; ++state) { + if (distances->get_init_distance(state) != INF && + distances->get_goal_distance(state) != INF) { + ++alive_states_count; + } } - } - /* - Compute the score as the ratio of alive states of the product - compared to the number of states of the full product. - */ - assert(num_states); - double score = static_cast(alive_states_count) / - static_cast(num_states); + /* + Compute the score as the ratio of alive states of the product + compared to the number of states of the full product. + */ + assert(num_states); + score = static_cast(alive_states_count) / + static_cast(num_states); + if (use_caching) { + cached_scores[id] = score; + } + } scores.push_back(score); } return scores; @@ -127,6 +137,7 @@ class MergeScoringFunctionMIASMFeature : public plugins::TypedFeature create_component(const plugins::Options &options, const utils::Context &context) const override { diff --git a/src/search/merge_and_shrink/merge_scoring_function_miasm.h b/src/search/merge_and_shrink/merge_scoring_function_miasm.h index e224d7b114..1c2e376449 100644 --- a/src/search/merge_and_shrink/merge_scoring_function_miasm.h +++ b/src/search/merge_and_shrink/merge_scoring_function_miasm.h @@ -22,7 +22,7 @@ class MergeScoringFunctionMIASM : public MergeScoringFunction { virtual ~MergeScoringFunctionMIASM() override = default; virtual std::vector compute_scores( const FactoredTransitionSystem &fts, - const std::vector> &merge_candidates) override; + const std::vector> &merge_candidates) override; virtual bool requires_init_distances() const override { return true; diff --git a/src/search/merge_and_shrink/merge_scoring_function_single_random.cc b/src/search/merge_and_shrink/merge_scoring_function_single_random.cc index 90e0eaf9bf..1c405641d0 100644 --- a/src/search/merge_and_shrink/merge_scoring_function_single_random.cc +++ b/src/search/merge_and_shrink/merge_scoring_function_single_random.cc @@ -14,13 +14,14 @@ using namespace std; namespace merge_and_shrink { MergeScoringFunctionSingleRandom::MergeScoringFunctionSingleRandom( const plugins::Options &options) - : random_seed(options.get("random_seed")), + : MergeScoringFunction(options), + random_seed(options.get("random_seed")), rng(utils::parse_rng_from_options(options)) { } vector MergeScoringFunctionSingleRandom::compute_scores( const FactoredTransitionSystem &, - const vector> &merge_candidates) { + const vector> &merge_candidates) { int chosen_index = rng->random(merge_candidates.size()); vector scores; scores.reserve(merge_candidates.size()); @@ -53,7 +54,7 @@ class MergeScoringFunctionSingleRandomFeature : public plugins::TypedFeature compute_scores( const FactoredTransitionSystem &fts, - const std::vector> &merge_candidates) override; + const std::vector> &merge_candidates) override; virtual bool requires_init_distances() const override { return false; diff --git a/src/search/merge_and_shrink/merge_scoring_function_total_order.cc b/src/search/merge_and_shrink/merge_scoring_function_total_order.cc index 8b26a37db6..b721cec4cb 100644 --- a/src/search/merge_and_shrink/merge_scoring_function_total_order.cc +++ b/src/search/merge_and_shrink/merge_scoring_function_total_order.cc @@ -18,7 +18,8 @@ using namespace std; namespace merge_and_shrink { MergeScoringFunctionTotalOrder::MergeScoringFunctionTotalOrder( const plugins::Options &options) - : atomic_ts_order(options.get("atomic_ts_order")), + : MergeScoringFunction(options), + atomic_ts_order(options.get("atomic_ts_order")), product_ts_order(options.get("product_ts_order")), atomic_before_product(options.get("atomic_before_product")), random_seed(options.get("random_seed")), @@ -27,15 +28,15 @@ MergeScoringFunctionTotalOrder::MergeScoringFunctionTotalOrder( vector MergeScoringFunctionTotalOrder::compute_scores( const FactoredTransitionSystem &, - const vector> &merge_candidates) { + const vector> &merge_candidates) { assert(initialized); vector scores; scores.reserve(merge_candidates.size()); for (size_t candidate_index = 0; candidate_index < merge_candidates.size(); ++candidate_index) { - pair merge_candidate = merge_candidates[candidate_index]; - int ts_index1 = merge_candidate.first; - int ts_index2 = merge_candidate.second; + const auto &merge_candidate = merge_candidates[candidate_index]; + int ts_index1 = merge_candidate->index1; + int ts_index2 = merge_candidate->index2; for (size_t merge_candidate_order_index = 0; merge_candidate_order_index < merge_candidate_order.size(); ++merge_candidate_order_index) { @@ -199,6 +200,7 @@ class MergeScoringFunctionTotalOrderFeature : public plugins::TypedFeature compute_scores( const FactoredTransitionSystem &fts, - const std::vector> &merge_candidates) override; + const std::vector> &merge_candidates) override; virtual void initialize(const TaskProxy &task_proxy) override; static void add_options_to_feature(plugins::Feature &feature); diff --git a/src/search/merge_and_shrink/merge_selector.cc b/src/search/merge_and_shrink/merge_selector.cc index 4bdf01d859..19983d2f48 100644 --- a/src/search/merge_and_shrink/merge_selector.cc +++ b/src/search/merge_and_shrink/merge_selector.cc @@ -1,8 +1,11 @@ #include "merge_selector.h" #include "factored_transition_system.h" +#include "merge_scoring_function.h" +#include "../task_proxy.h" #include "../plugins/plugin.h" +#include "../utils/collections.h" #include "../utils/logging.h" #include @@ -11,17 +14,29 @@ using namespace std; namespace merge_and_shrink { -vector> MergeSelector::compute_merge_candidates( +shared_ptr MergeSelector::get_candidate( + int index1, int index2) { + assert(utils::in_bounds(index1, merge_candidates_by_indices)); + assert(utils::in_bounds(index2, merge_candidates_by_indices[index1])); + if (merge_candidates_by_indices[index1][index2] == nullptr) { + merge_candidates_by_indices[index1][index2] = + make_shared(num_candidates, index1, index2); + ++num_candidates; + } + return merge_candidates_by_indices[index1][index2]; +} + +vector> MergeSelector::compute_merge_candidates( const FactoredTransitionSystem &fts, - const vector &indices_subset) const { - vector> merge_candidates; + const vector &indices_subset) { + vector> merge_candidates; if (indices_subset.empty()) { for (int ts_index1 = 0; ts_index1 < fts.get_size(); ++ts_index1) { if (fts.is_active(ts_index1)) { for (int ts_index2 = ts_index1 + 1; ts_index2 < fts.get_size(); ++ts_index2) { if (fts.is_active(ts_index2)) { - merge_candidates.emplace_back(ts_index1, ts_index2); + merge_candidates.push_back(get_candidate(ts_index1, ts_index2)); } } } @@ -34,13 +49,21 @@ vector> MergeSelector::compute_merge_candidates( for (size_t j = i + 1; j < indices_subset.size(); ++j) { int ts_index2 = indices_subset[j]; assert(fts.is_active(ts_index2)); - merge_candidates.emplace_back(ts_index1, ts_index2); + merge_candidates.push_back(get_candidate(ts_index1, ts_index2)); } } } return merge_candidates; } +void MergeSelector::initialize(const TaskProxy &task_proxy) { + int num_variables = task_proxy.get_variables().size(); + int max_factor_index = 2 * num_variables - 1; + merge_candidates_by_indices.resize( + max_factor_index, + vector>(max_factor_index, nullptr)); +} + void MergeSelector::dump_options(utils::LogProxy &log) const { if (log.is_at_least_normal()) { log << "Merge selector options:" << endl; diff --git a/src/search/merge_and_shrink/merge_selector.h b/src/search/merge_and_shrink/merge_selector.h index 3cc56f5247..53f052afbc 100644 --- a/src/search/merge_and_shrink/merge_selector.h +++ b/src/search/merge_and_shrink/merge_selector.h @@ -1,6 +1,7 @@ #ifndef MERGE_AND_SHRINK_MERGE_SELECTOR_H #define MERGE_AND_SHRINK_MERGE_SELECTOR_H +#include #include #include @@ -12,20 +13,25 @@ class LogProxy; namespace merge_and_shrink { class FactoredTransitionSystem; +struct MergeCandidate; class MergeSelector { protected: + std::vector>> merge_candidates_by_indices; + int num_candidates; + + std::shared_ptr get_candidate(int index1, int index2); virtual std::string name() const = 0; virtual void dump_selector_specific_options(utils::LogProxy &) const {} - std::vector> compute_merge_candidates( + std::vector> compute_merge_candidates( const FactoredTransitionSystem &fts, - const std::vector &indices_subset) const; + const std::vector &indices_subset); public: MergeSelector() = default; virtual ~MergeSelector() = default; virtual std::pair select_merge( const FactoredTransitionSystem &fts, - const std::vector &indices_subset = std::vector()) const = 0; - virtual void initialize(const TaskProxy &task_proxy) = 0; + const std::vector &indices_subset = std::vector()) = 0; + virtual void initialize(const TaskProxy &task_proxy); void dump_options(utils::LogProxy &log) const; virtual bool requires_init_distances() const = 0; virtual bool requires_goal_distances() const = 0; diff --git a/src/search/merge_and_shrink/merge_selector_score_based_filtering.cc b/src/search/merge_and_shrink/merge_selector_score_based_filtering.cc index 4b6e96206a..7795de4989 100644 --- a/src/search/merge_and_shrink/merge_selector_score_based_filtering.cc +++ b/src/search/merge_and_shrink/merge_selector_score_based_filtering.cc @@ -17,9 +17,9 @@ MergeSelectorScoreBasedFiltering::MergeSelectorScoreBasedFiltering( "scoring_functions")) { } -vector> MergeSelectorScoreBasedFiltering::get_remaining_candidates( - const vector> &merge_candidates, - const vector &scores) const { +static vector> get_remaining_candidates( + const vector> &merge_candidates, + const vector &scores) { assert(merge_candidates.size() == scores.size()); double best_score = INF; for (double score : scores) { @@ -28,7 +28,7 @@ vector> MergeSelectorScoreBasedFiltering::get_remaining_candidate } } - vector> result; + vector> result; for (size_t i = 0; i < scores.size(); ++i) { if (scores[i] == best_score) { result.push_back(merge_candidates[i]); @@ -39,8 +39,8 @@ vector> MergeSelectorScoreBasedFiltering::get_remaining_candidate pair MergeSelectorScoreBasedFiltering::select_merge( const FactoredTransitionSystem &fts, - const vector &indices_subset) const { - vector> merge_candidates = + const vector &indices_subset) { + vector> merge_candidates = compute_merge_candidates(fts, indices_subset); for (const shared_ptr &scoring_function : @@ -60,10 +60,11 @@ pair MergeSelectorScoreBasedFiltering::select_merge( utils::exit_with(utils::ExitCode::SEARCH_CRITICAL_ERROR); } - return merge_candidates.front(); + return make_pair(merge_candidates.front()->index1, merge_candidates.front()->index2); } void MergeSelectorScoreBasedFiltering::initialize(const TaskProxy &task_proxy) { + MergeSelector::initialize(task_proxy); for (shared_ptr &scoring_function : merge_scoring_functions) { scoring_function->initialize(task_proxy); diff --git a/src/search/merge_and_shrink/merge_selector_score_based_filtering.h b/src/search/merge_and_shrink/merge_selector_score_based_filtering.h index ebf01351be..58efbc1b4b 100644 --- a/src/search/merge_and_shrink/merge_selector_score_based_filtering.h +++ b/src/search/merge_and_shrink/merge_selector_score_based_filtering.h @@ -3,9 +3,6 @@ #include "merge_selector.h" -#include "merge_scoring_function.h" - -#include #include namespace plugins { @@ -13,12 +10,9 @@ class Options; } namespace merge_and_shrink { +class MergeScoringFunction; class MergeSelectorScoreBasedFiltering : public MergeSelector { std::vector> merge_scoring_functions; - - std::vector> get_remaining_candidates( - const std::vector> &merge_candidates, - const std::vector &scores) const; protected: virtual std::string name() const override; virtual void dump_selector_specific_options(utils::LogProxy &log) const override; @@ -27,7 +21,7 @@ class MergeSelectorScoreBasedFiltering : public MergeSelector { virtual ~MergeSelectorScoreBasedFiltering() override = default; virtual std::pair select_merge( const FactoredTransitionSystem &fts, - const std::vector &indices_subset = std::vector()) const override; + const std::vector &indices_subset = std::vector()) override; virtual void initialize(const TaskProxy &task_proxy) override; virtual bool requires_init_distances() const override; virtual bool requires_goal_distances() const override;