Skip to content

Commit

Permalink
add caching option
Browse files Browse the repository at this point in the history
  • Loading branch information
Silvan Sievers committed Jul 17, 2023
1 parent 9831694 commit 5044d41
Show file tree
Hide file tree
Showing 16 changed files with 164 additions and 92 deletions.
13 changes: 11 additions & 2 deletions src/search/merge_and_shrink/merge_scoring_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
using namespace std;

namespace merge_and_shrink {
MergeScoringFunction::MergeScoringFunction()
: initialized(false) {
MergeScoringFunction::MergeScoringFunction(const plugins::Options &options)
: use_caching(options.get<bool>("use_caching")),
initialized(false) {
}

void MergeScoringFunction::dump_options(utils::LogProxy &log) const {
Expand All @@ -20,6 +21,14 @@ void MergeScoringFunction::dump_options(utils::LogProxy &log) const {
}
}

void add_merge_scoring_function_options_to_feature(plugins::Feature &feature) {
feature.add_option<bool>(
"use_caching",
"Cache scores for merge candidates. Currently only supported by the "
"MIASM scoring function.",
"false");
}

static class MergeScoringFunctionCategoryPlugin : public plugins::TypedCategoryPlugin<MergeScoringFunction> {
public:
MergeScoringFunctionCategoryPlugin() : TypedCategoryPlugin("MergeScoringFunction") {
Expand Down
25 changes: 23 additions & 2 deletions src/search/merge_and_shrink/merge_scoring_function.h
Original file line number Diff line number Diff line change
@@ -1,28 +1,47 @@
#ifndef MERGE_AND_SHRINK_MERGE_SCORING_FUNCTION_H
#define MERGE_AND_SHRINK_MERGE_SCORING_FUNCTION_H

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

class TaskProxy;

namespace plugins {
class Feature;
class Options;
}

namespace utils {
class LogProxy;
}

namespace merge_and_shrink {
class FactoredTransitionSystem;
struct MergeCandidate {
int id;
int index1;
int index2;
MergeCandidate(int id, int index1, int index2)
: id(id), index1(index1), index2(index2) {
}
};

class MergeScoringFunction {
protected:
const bool use_caching;
bool initialized;
std::unordered_map<int, double> cached_scores;

virtual std::string name() const = 0;
virtual void dump_function_specific_options(utils::LogProxy &) const {}
public:
MergeScoringFunction();
explicit MergeScoringFunction(const plugins::Options &options);
virtual ~MergeScoringFunction() = default;
virtual std::vector<double> compute_scores(
const FactoredTransitionSystem &fts,
const std::vector<std::pair<int, int>> &merge_candidates) = 0;
const std::vector<std::shared_ptr<MergeCandidate>> &merge_candidates) = 0;
virtual bool requires_init_distances() const = 0;
virtual bool requires_goal_distances() const = 0;

Expand All @@ -33,6 +52,8 @@ class MergeScoringFunction {

void dump_options(utils::LogProxy &log) const;
};

extern void add_merge_scoring_function_options_to_feature(plugins::Feature &feature);
}

#endif
18 changes: 10 additions & 8 deletions src/search/merge_and_shrink/merge_scoring_function_dfp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
using namespace std;

namespace merge_and_shrink {
MergeScoringFunctionDFP::MergeScoringFunctionDFP(
const plugins::Options &options)
: MergeScoringFunction(options) {
}

vector<int> MergeScoringFunctionDFP::compute_label_ranks(
const FactoredTransitionSystem &fts, int index) const {
const TransitionSystem &ts = fts.get_transition_system(index);
Expand Down Expand Up @@ -60,17 +65,17 @@ vector<int> MergeScoringFunctionDFP::compute_label_ranks(

vector<double> MergeScoringFunctionDFP::compute_scores(
const FactoredTransitionSystem &fts,
const vector<pair<int, int>> &merge_candidates) {
const vector<shared_ptr<MergeCandidate>> &merge_candidates) {
int num_ts = fts.get_size();

vector<vector<int>> transition_system_label_ranks(num_ts);
vector<double> scores;
scores.reserve(merge_candidates.size());

// Go over all pairs of transition systems and compute their weight.
for (pair<int, int> merge_candidate : merge_candidates) {
int ts_index1 = merge_candidate.first;
int ts_index2 = merge_candidate.second;
for (const auto &merge_candidate : merge_candidates) {
int ts_index1 = merge_candidate->index1;
int ts_index2 = merge_candidate->index2;

vector<int> &label_ranks1 = transition_system_label_ranks[ts_index1];
if (label_ranks1.empty()) {
Expand Down Expand Up @@ -129,10 +134,7 @@ class MergeScoringFunctionDFPFeature : public plugins::TypedFeature<MergeScoring
"atomic_before_product=true)])),shrink_strategy=shrink_bisimulation("
"greedy=false),label_reduction=exact(before_shrinking=true,"
"before_merging=false),max_states=50000,threshold_before_merge=1)\n}}}");
}

virtual shared_ptr<MergeScoringFunctionDFP> create_component(const plugins::Options &, const utils::Context &) const override {
return make_shared<MergeScoringFunctionDFP>();
add_merge_scoring_function_options_to_feature(*this);
}
};

Expand Down
4 changes: 2 additions & 2 deletions src/search/merge_and_shrink/merge_scoring_function_dfp.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ class MergeScoringFunctionDFP : public MergeScoringFunction {
protected:
virtual std::string name() const override;
public:
MergeScoringFunctionDFP() = default;
explicit MergeScoringFunctionDFP(const plugins::Options &options);
virtual ~MergeScoringFunctionDFP() override = default;
virtual std::vector<double> compute_scores(
const FactoredTransitionSystem &fts,
const std::vector<std::pair<int, int>> &merge_candidates) override;
const std::vector<std::shared_ptr<MergeCandidate>> &merge_candidates) override;

virtual bool requires_init_distances() const override {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@
using namespace std;

namespace merge_and_shrink {
MergeScoringFunctionGoalRelevance::MergeScoringFunctionGoalRelevance(
const plugins::Options &options)
: MergeScoringFunction(options) {
}

vector<double> MergeScoringFunctionGoalRelevance::compute_scores(
const FactoredTransitionSystem &fts,
const vector<pair<int, int>> &merge_candidates) {
const vector<shared_ptr<MergeCandidate>> &merge_candidates) {
int num_ts = fts.get_size();
vector<bool> goal_relevant(num_ts, false);
for (int ts_index : fts) {
Expand All @@ -23,9 +28,9 @@ vector<double> MergeScoringFunctionGoalRelevance::compute_scores(

vector<double> scores;
scores.reserve(merge_candidates.size());
for (pair<int, int> merge_candidate : merge_candidates) {
int ts_index1 = merge_candidate.first;
int ts_index2 = merge_candidate.second;
for (const auto &merge_candidate : merge_candidates) {
int ts_index1 = merge_candidate->index1;
int ts_index2 = merge_candidate->index2;
int score = INF;
if (goal_relevant[ts_index1] || goal_relevant[ts_index2]) {
score = 0;
Expand All @@ -48,10 +53,7 @@ class MergeScoringFunctionGoalRelevanceFeature : public plugins::TypedFeature<Me
"least one of the two transition systems of the merge candidate is "
"goal relevant in the sense that there is an abstract non-goal state. "
"All other candidates get a score of positive infinity.");
}

virtual shared_ptr<MergeScoringFunctionGoalRelevance> create_component(const plugins::Options &, const utils::Context &) const override {
return make_shared<MergeScoringFunctionGoalRelevance>();
add_merge_scoring_function_options_to_feature(*this);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ class MergeScoringFunctionGoalRelevance : public MergeScoringFunction {
protected:
virtual std::string name() const override;
public:
MergeScoringFunctionGoalRelevance() = default;
MergeScoringFunctionGoalRelevance(const plugins::Options &options);
virtual ~MergeScoringFunctionGoalRelevance() override = default;
virtual std::vector<double> compute_scores(
const FactoredTransitionSystem &fts,
const std::vector<std::pair<int, int>> &merge_candidates) override;
const std::vector<std::shared_ptr<MergeCandidate>> &merge_candidates) override;

virtual bool requires_init_distances() const override {
return false;
Expand Down
77 changes: 44 additions & 33 deletions src/search/merge_and_shrink/merge_scoring_function_miasm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ using namespace std;
namespace merge_and_shrink {
MergeScoringFunctionMIASM::MergeScoringFunctionMIASM(
const plugins::Options &options)
: shrink_strategy(options.get<shared_ptr<ShrinkStrategy>>("shrink_strategy")),
: MergeScoringFunction(options),
shrink_strategy(options.get<shared_ptr<ShrinkStrategy>>("shrink_strategy")),
max_states(options.get<int>("max_states")),
max_states_before_merge(options.get<int>("max_states_before_merge")),
shrink_threshold_before_merge(options.get<int>("threshold_before_merge")),
Expand All @@ -25,43 +26,52 @@ MergeScoringFunctionMIASM::MergeScoringFunctionMIASM(

vector<double> MergeScoringFunctionMIASM::compute_scores(
const FactoredTransitionSystem &fts,
const vector<pair<int, int>> &merge_candidates) {
const vector<shared_ptr<MergeCandidate>> &merge_candidates) {
vector<double> scores;
scores.reserve(merge_candidates.size());
for (pair<int, int> merge_candidate : merge_candidates) {
int index1 = merge_candidate.first;
int index2 = merge_candidate.second;
unique_ptr<TransitionSystem> product = shrink_before_merge_externally(
fts,
index1,
index2,
*shrink_strategy,
max_states,
max_states_before_merge,
shrink_threshold_before_merge,
silent_log);
for (const auto &merge_candidate : merge_candidates) {
double score;
int id = merge_candidate->id;
if (use_caching && cached_scores.count(id)) {
score = cached_scores[id];
} else {
int index1 = merge_candidate->index1;
int index2 = merge_candidate->index2;
unique_ptr<TransitionSystem> product = shrink_before_merge_externally(
fts,
index1,
index2,
*shrink_strategy,
max_states,
max_states_before_merge,
shrink_threshold_before_merge,
silent_log);

// Compute distances for the product and count the alive states.
unique_ptr<Distances> distances = utils::make_unique_ptr<Distances>(*product);
const bool compute_init_distances = true;
const bool compute_goal_distances = true;
distances->compute_distances(compute_init_distances, compute_goal_distances, silent_log);
int num_states = product->get_size();
int alive_states_count = 0;
for (int state = 0; state < num_states; ++state) {
if (distances->get_init_distance(state) != INF &&
distances->get_goal_distance(state) != INF) {
++alive_states_count;
// Compute distances for the product and count the alive states.
unique_ptr<Distances> distances = utils::make_unique_ptr<Distances>(*product);
const bool compute_init_distances = true;
const bool compute_goal_distances = true;
distances->compute_distances(compute_init_distances, compute_goal_distances, silent_log);
int num_states = product->get_size();
int alive_states_count = 0;
for (int state = 0; state < num_states; ++state) {
if (distances->get_init_distance(state) != INF &&
distances->get_goal_distance(state) != INF) {
++alive_states_count;
}
}
}

/*
Compute the score as the ratio of alive states of the product
compared to the number of states of the full product.
*/
assert(num_states);
double score = static_cast<double>(alive_states_count) /
static_cast<double>(num_states);
/*
Compute the score as the ratio of alive states of the product
compared to the number of states of the full product.
*/
assert(num_states);
score = static_cast<double>(alive_states_count) /
static_cast<double>(num_states);
if (use_caching) {
cached_scores[id] = score;
}
}
scores.push_back(score);
}
return scores;
Expand Down Expand Up @@ -127,6 +137,7 @@ class MergeScoringFunctionMIASMFeature : public plugins::TypedFeature<MergeScori
"amount of possible pruning, merge-and-shrink should be configured to "
"use full pruning, i.e. {{{prune_unreachable_states=true}}} and {{{"
"prune_irrelevant_states=true}}} (the default).");
add_merge_scoring_function_options_to_feature(*this);
}

virtual shared_ptr<MergeScoringFunctionMIASM> create_component(const plugins::Options &options, const utils::Context &context) const override {
Expand Down
2 changes: 1 addition & 1 deletion src/search/merge_and_shrink/merge_scoring_function_miasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class MergeScoringFunctionMIASM : public MergeScoringFunction {
virtual ~MergeScoringFunctionMIASM() override = default;
virtual std::vector<double> compute_scores(
const FactoredTransitionSystem &fts,
const std::vector<std::pair<int, int>> &merge_candidates) override;
const std::vector<std::shared_ptr<MergeCandidate>> &merge_candidates) override;

virtual bool requires_init_distances() const override {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ using namespace std;
namespace merge_and_shrink {
MergeScoringFunctionSingleRandom::MergeScoringFunctionSingleRandom(
const plugins::Options &options)
: random_seed(options.get<int>("random_seed")),
: MergeScoringFunction(options),
random_seed(options.get<int>("random_seed")),
rng(utils::parse_rng_from_options(options)) {
}

vector<double> MergeScoringFunctionSingleRandom::compute_scores(
const FactoredTransitionSystem &,
const vector<pair<int, int>> &merge_candidates) {
const vector<shared_ptr<MergeCandidate>> &merge_candidates) {
int chosen_index = rng->random(merge_candidates.size());
vector<double> scores;
scores.reserve(merge_candidates.size());
Expand Down Expand Up @@ -53,7 +54,7 @@ class MergeScoringFunctionSingleRandomFeature : public plugins::TypedFeature<Mer
document_synopsis(
"This scoring function assigns exactly one merge candidate a score of "
"0, chosen randomly, and infinity to all others.");

add_merge_scoring_function_options_to_feature(*this);
utils::add_rng_options(*this);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class MergeScoringFunctionSingleRandom : public MergeScoringFunction {
virtual ~MergeScoringFunctionSingleRandom() override = default;
virtual std::vector<double> compute_scores(
const FactoredTransitionSystem &fts,
const std::vector<std::pair<int, int>> &merge_candidates) override;
const std::vector<std::shared_ptr<MergeCandidate>> &merge_candidates) override;

virtual bool requires_init_distances() const override {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ using namespace std;
namespace merge_and_shrink {
MergeScoringFunctionTotalOrder::MergeScoringFunctionTotalOrder(
const plugins::Options &options)
: atomic_ts_order(options.get<AtomicTSOrder>("atomic_ts_order")),
: MergeScoringFunction(options),
atomic_ts_order(options.get<AtomicTSOrder>("atomic_ts_order")),
product_ts_order(options.get<ProductTSOrder>("product_ts_order")),
atomic_before_product(options.get<bool>("atomic_before_product")),
random_seed(options.get<int>("random_seed")),
Expand All @@ -27,15 +28,15 @@ MergeScoringFunctionTotalOrder::MergeScoringFunctionTotalOrder(

vector<double> MergeScoringFunctionTotalOrder::compute_scores(
const FactoredTransitionSystem &,
const vector<pair<int, int>> &merge_candidates) {
const vector<shared_ptr<MergeCandidate>> &merge_candidates) {
assert(initialized);
vector<double> scores;
scores.reserve(merge_candidates.size());
for (size_t candidate_index = 0; candidate_index < merge_candidates.size();
++candidate_index) {
pair<int, int> merge_candidate = merge_candidates[candidate_index];
int ts_index1 = merge_candidate.first;
int ts_index2 = merge_candidate.second;
const auto &merge_candidate = merge_candidates[candidate_index];
int ts_index1 = merge_candidate->index1;
int ts_index2 = merge_candidate->index2;
for (size_t merge_candidate_order_index = 0;
merge_candidate_order_index < merge_candidate_order.size();
++merge_candidate_order_index) {
Expand Down Expand Up @@ -199,6 +200,7 @@ class MergeScoringFunctionTotalOrderFeature : public plugins::TypedFeature<Merge
"if used alone in a score based filtering merge selector, can be used "
"to emulate the corresponding (precomputed) linear merge strategies "
"reverse level/level (independently of the other options).");
add_merge_scoring_function_options_to_feature(*this);
MergeScoringFunctionTotalOrder::add_options_to_feature(*this);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class MergeScoringFunctionTotalOrder : public MergeScoringFunction {
virtual ~MergeScoringFunctionTotalOrder() override = default;
virtual std::vector<double> compute_scores(
const FactoredTransitionSystem &fts,
const std::vector<std::pair<int, int>> &merge_candidates) override;
const std::vector<std::shared_ptr<MergeCandidate>> &merge_candidates) override;
virtual void initialize(const TaskProxy &task_proxy) override;
static void add_options_to_feature(plugins::Feature &feature);

Expand Down
Loading

0 comments on commit 5044d41

Please sign in to comment.