diff --git a/README.md b/README.md
index 07b92755..53da4a35 100644
--- a/README.md
+++ b/README.md
@@ -41,26 +41,31 @@ python3 run_gnn.py <domain_pddl> <problem_pddl> -m <model> -r <representation>
+```
+for <representation> from `llg, dlg, slg, glg, flg`, or generate them all at once with
+```
+sh dataset/generate_all_graphs_gnn.sh
 ```
 
 #### Domain-dependent training
-Requires packages in `requirements.txt` or alternatively use the singularity container as in [Search
-Evaluation](#search-evaluation). To train, go into ```learner``` directory (`cd learner`). Then run
+Requires the packages in `requirements.txt`, or alternatively use the singularity container as in [Search](#search). To train, go
+into the ```learner``` directory (`cd learner`) and run
 ```
 python3 train_gnn.py -m RGNN -r llg -d goose-<domain>-only --save-file <save_file>
 ```
-where you replace ```<domain>``` by any domain from ```blocks, ferry, gripper, n-puzzle, sokoban, spanner, visitall,
-visitsome``` and ```<save_file>``` is the name of the save file ending in `.dt` for the trained weights of the models which
-would then be located in ```trained_models/``` after training.
+where you replace `<domain>` with any domain from `blocks, ferry, gripper, n-puzzle, sokoban, spanner, visitall,
+visitsome`, and `<save_file>` is the name of the save file, ending in `.dt`, for the trained model weights, which
+are then located in `trained_models/` after training.
 
 ## Kernels
 
 ### Search
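Taken together, the README steps form one pipeline. Below is a minimal sketch of that pipeline as a Python driver, in the style of `run_gnn.py` and `test_gnn.sh`; the `gripper` domain and all file names are illustrative, not prescribed by the README:
```
# Minimal driver mirroring the README workflow: generate graphs, train a
# domain-dependent GNN, then run search with the learned heuristic.
# All paths and the "gripper" domain are illustrative assumptions.
import os

os.chdir("learner")

# 1. generate graph representations (or run dataset/generate_all_graphs_gnn.sh)
os.system("python3 dataset/generate_graphs_gnn.py llg --regenerate")

# 2. domain-dependent training; weights land in trained_models/gripper_llg.dt
os.system("python3 train_gnn.py -m RGNN -r llg -d goose-gripper-only --save-file gripper_llg.dt")

# 3. search on a held-out test instance, as in test_gnn.sh
os.system(
    "python3 run_gnn.py ../benchmarks/goose/gripper/domain.pddl "
    "../benchmarks/goose/gripper/test/gripper-n20.pddl "
    "-m trained_models/gripper_llg.dt -r llg"
)
```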
diff --git a/downward/src/search/heuristics/goose_heuristic.cc b/downward/src/search/heuristics/goose_heuristic.cc
index 1906aff2..48a678b7 100644
--- a/downward/src/search/heuristics/goose_heuristic.cc
+++ b/downward/src/search/heuristics/goose_heuristic.cc
@@ -13,10 +13,8 @@ using std::string;
 namespace goose_heuristic {
 GooseHeuristic::GooseHeuristic(const plugins::Options &opts)
     : Heuristic(opts) {
-  initialise_model(opts);
   initialise_fact_strings();
-
 }
 
 void GooseHeuristic::initialise_model(const plugins::Options &opts) {
@@ -42,48 +40,10 @@ void GooseHeuristic::initialise_model(const plugins::Options &opts) {
   // python will be printed to stderr, even if it is not an error.
   sys.attr("stderr") = sys.attr("stdout");
 
-  // A really disgusting hack because FeaturePlugin cannot parse string options
-  std::string config_path;
-  switch (opts.get<int>("graph"))
-  {
-    case 0: config_path = "slg"; break;
-    case 1: config_path = "flg"; break;
-    case 2: config_path = "dlg"; break;
-    case 3: config_path = "llg"; break;
-    default:
-      std::cout << "Unknown enum of graph representation" << std::endl;
-      exit(-1);
-  }
-
-  // Parse paths from file at config_path
-  std::string model_path;
-  std::string domain_file;
-  std::string instance_file;
-
-  std::string line;
-  std::ifstream config_file(config_path);
-  int file_line = 0;
-
-  // TODO see https://github.com/aibasel/downward/pull/170
-  while (getline(config_file, line)) {
-    switch (file_line) {
-      case 0:
-        model_path = line;
-        break;
-      case 1:
-        domain_file = line;
-        break;
-      case 2:
-        instance_file = line;
-        break;
-      default:
-        std::cout << "config file " << config_path
-                  << " must only have 3 lines" << std::endl;
-        exit(-1);
-    }
-    file_line++;
-  }
-  config_file.close();
+  // Read paths
+  std::string model_path = opts.get<std::string>("model_path");
+  std::string domain_file = opts.get<std::string>("domain_file");
+  std::string instance_file = opts.get<std::string>("instance_file");
 
   // Throw everything into Python code
   std::cout << "Trying to load model from file " << model_path << " ...\n";
@@ -189,27 +149,19 @@ class GooseHeuristicFeature : public plugins::TypedFeature<Evaluator, GooseHeuristic> {
-        add_option<int>(
-            "graph",
-            "0: slg, 1: flg, 2: llg, 3: glg",
-            "-1");
-
-        // add_option does not work with <std::string>
-
-        // add_option<std::string>(
-        //     "model_path",
-        //     "path to trained model weights of file type .dt",
-        //     "default_value.dt");
-
-        // add_option<std::string>(
-        //     "domain_file",
-        //     "Path to the domain file.",
-        //     "default_file.pddl");
-
-        // add_option<std::string>(
-        //     "instance_file",
-        //     "Path to the instance file.",
-        //     "default_file.pddl");
+        // https://github.com/aibasel/downward/pull/170 for string options
+        add_option<std::string>(
+            "model_path",
+            "path to trained model weights of file type .dt",
+            "default_value.dt");
+        add_option<std::string>(
+            "domain_file",
+            "Path to the domain file.",
+            "default_file.pddl");
+        add_option<std::string>(
+            "instance_file",
+            "Path to the instance file.",
+            "default_file.pddl");
 
         Heuristic::add_options_to_feature(*this);
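With `model_path`, `domain_file`, and `instance_file` exposed as real string options, the heuristic can now be configured inline on the planner command line instead of through the removed config-file hack. A sketch of the resulting invocation, built the same way the updated `learner/util/search.py` later in this diff builds it; the paths and the `eager_greedy` engine are placeholders:
```
# Sketch of the search call enabled by the new string options; model and PDDL
# paths are illustrative. Mirrors fd_cmd() in learner/util/search.py below.
model = "trained_models/gripper_llg.dt"
df, pf = "domain.pddl", "problem.pddl"
search = "eager_greedy"  # assumed search engine, not fixed by this patch

cmd = (
    f"./../downward/fast-downward.py {df} {pf} --search "
    f"'{search}([goose(model_path=\"{model}\", domain_file=\"{df}\", instance_file=\"{pf}\")])'"
)
print(cmd)
```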
diff --git a/downward/src/search/parser/abstract_syntax_tree.cc b/downward/src/search/parser/abstract_syntax_tree.cc
index 5aecdb72..27e5b670 100644
--- a/downward/src/search/parser/abstract_syntax_tree.cc
+++ b/downward/src/search/parser/abstract_syntax_tree.cc
@@ -419,6 +419,8 @@ DecoratedASTNodePtr LiteralNode::decorate(DecorateContext &context) const {
     switch (value.type) {
     case TokenType::BOOLEAN:
         return utils::make_unique_ptr<BoolLiteralNode>(value.content);
+    case TokenType::STRING:
+        return utils::make_unique_ptr<StringLiteralNode>(value.content);
     case TokenType::INTEGER:
         return utils::make_unique_ptr<IntLiteralNode>(value.content);
     case TokenType::FLOAT:
@@ -440,6 +442,8 @@ const plugins::Type &LiteralNode::get_type(DecorateContext &context) const {
     switch (value.type) {
     case TokenType::BOOLEAN:
         return plugins::TypeRegistry::instance()->get_type<bool>();
+    case TokenType::STRING:
+        return plugins::TypeRegistry::instance()->get_type<std::string>();
     case TokenType::INTEGER:
         return plugins::TypeRegistry::instance()->get_type<int>();
     case TokenType::FLOAT:
@@ -454,4 +458,4 @@ const plugins::Type &LiteralNode::get_type(DecorateContext &context) const {
                      token_type_name(value.type) + "'.");
     }
 }
-}
+}
\ No newline at end of file
diff --git a/downward/src/search/parser/decorated_abstract_syntax_tree.cc b/downward/src/search/parser/decorated_abstract_syntax_tree.cc
index 3a401d9e..068ee593 100644
--- a/downward/src/search/parser/decorated_abstract_syntax_tree.cc
+++ b/downward/src/search/parser/decorated_abstract_syntax_tree.cc
@@ -218,6 +218,19 @@ void BoolLiteralNode::dump(string indent) const {
     cout << indent << "BOOL: " << value << endl;
 }
 
+StringLiteralNode::StringLiteralNode(const string &value)
+    : value(value) {
+}
+
+plugins::Any StringLiteralNode::construct(ConstructContext &context) const {
+    utils::TraceBlock block(context, "Constructing string value from '" + value + "'");
+    return value;
+}
+
+void StringLiteralNode::dump(string indent) const {
+    cout << indent << "STRING: " << value << endl;
+}
+
 IntLiteralNode::IntLiteralNode(const string &value)
     : value(value) {
 }
@@ -473,6 +486,18 @@ shared_ptr<DecoratedASTNode> BoolLiteralNode::clone_shared() const {
     return make_shared<BoolLiteralNode>(*this);
 }
 
+StringLiteralNode::StringLiteralNode(const StringLiteralNode &other)
+    : value(other.value) {
+}
+
+unique_ptr<DecoratedASTNode> StringLiteralNode::clone() const {
+    return utils::make_unique_ptr<StringLiteralNode>(*this);
+}
+
+shared_ptr<DecoratedASTNode> StringLiteralNode::clone_shared() const {
+    return make_shared<StringLiteralNode>(*this);
+}
+
 IntLiteralNode::IntLiteralNode(const IntLiteralNode &other)
     : value(other.value) {
 }
@@ -534,4 +559,4 @@ unique_ptr<DecoratedASTNode> CheckBoundsNode::clone() const {
 shared_ptr<DecoratedASTNode> CheckBoundsNode::clone_shared() const {
     return make_shared<CheckBoundsNode>(*this);
 }
-}
+}
\ No newline at end of file
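The parser changes all follow one pattern: every switch over literal token types gains a `STRING` arm. A purely illustrative Python analogue of the dispatch in `LiteralNode::decorate` (not project code):
```
# Illustrative analogue of LiteralNode::decorate / get_type: one table per
# switch, each gaining a STRING row. Class names mirror the C++ node classes.
from enum import Enum, auto

class TokenType(Enum):
    BOOLEAN = auto()
    STRING = auto()   # new token type introduced by this patch
    INTEGER = auto()
    FLOAT = auto()

NODE_CLASS = {
    TokenType.BOOLEAN: "BoolLiteralNode",
    TokenType.STRING: "StringLiteralNode",
    TokenType.INTEGER: "IntLiteralNode",
    TokenType.FLOAT: "FloatLiteralNode",
}

def decorate(token_type, content):
    # unknown literal types fail loudly, like the C++ default: case
    if token_type not in NODE_CLASS:
        raise ValueError(f"Literal '{content}' of unknown token type '{token_type}'.")
    return (NODE_CLASS[token_type], content)

print(decorate(TokenType.STRING, "trained_models/gripper_llg.dt"))
```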
diff --git a/downward/src/search/parser/decorated_abstract_syntax_tree.h b/downward/src/search/parser/decorated_abstract_syntax_tree.h
index 0094f887..6561560e 100644
--- a/downward/src/search/parser/decorated_abstract_syntax_tree.h
+++ b/downward/src/search/parser/decorated_abstract_syntax_tree.h
@@ -157,6 +157,20 @@ class BoolLiteralNode : public DecoratedASTNode {
     BoolLiteralNode(const BoolLiteralNode &other);
 };
 
+class StringLiteralNode : public DecoratedASTNode {
+    std::string value;
+public:
+    StringLiteralNode(const std::string &value);
+
+    plugins::Any construct(ConstructContext &context) const override;
+    void dump(std::string indent) const override;
+
+    // TODO: once we get rid of lazy construction, this should no longer be necessary.
+    virtual std::unique_ptr<DecoratedASTNode> clone() const override;
+    virtual std::shared_ptr<DecoratedASTNode> clone_shared() const override;
+    StringLiteralNode(const StringLiteralNode &other);
+};
+
 class IntLiteralNode : public DecoratedASTNode {
     std::string value;
 public:
@@ -234,4 +248,4 @@ class CheckBoundsNode : public DecoratedASTNode {
     CheckBoundsNode(const CheckBoundsNode &other);
 };
 }
-#endif
+#endif
\ No newline at end of file
diff --git a/downward/src/search/parser/lexical_analyzer.cc b/downward/src/search/parser/lexical_analyzer.cc
index a127aed9..f31f230d 100644
--- a/downward/src/search/parser/lexical_analyzer.cc
+++ b/downward/src/search/parser/lexical_analyzer.cc
@@ -29,6 +29,8 @@ static vector<pair<TokenType, regex>> construct_token_type_expressions() {
         {TokenType::INTEGER, R"([+-]?(infinity|\d+([kmg]\b)?))"},
         {TokenType::BOOLEAN, R"(true|false)"},
+        // TODO: support quoted strings.
+        {TokenType::STRING, R"("([^"]*)\")"},
         {TokenType::LET, R"(let)"},
         {TokenType::IDENTIFIER, R"([a-zA-Z_]\w*)"}
     };
@@ -59,7 +61,13 @@ TokenStream split_tokens(const string &text) {
             TokenType token_type = type_and_expression.first;
             const regex &expression = type_and_expression.second;
             if (regex_search(start, end, match, expression)) {
-                tokens.push_back({utils::tolower(match[1]), token_type});
+                string value;
+                if (token_type == TokenType::STRING) {
+                    value = match[2];
+                } else {
+                    value = utils::tolower(match[1]);
+                }
+                tokens.push_back({value, token_type});
                 start += match[0].length();
                 has_match = true;
                 break;
@@ -86,4 +94,4 @@ TokenStream split_tokens(const string &text) {
     }
     return TokenStream(move(tokens));
 }
-}
+}
\ No newline at end of file
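The `match[2]` index in the STRING branch is easy to misread: Fast Downward appears to wrap each token expression in an outer capture group (which is why the other branches read `match[1]`), so group 1 is the whole quoted token and group 2 is the inner `([^"]*)` content. A quick standalone check of that grouping in Python; the explicit wrapper group here is an assumption about how the C++ side composes the regex:
```
# Demonstrates why the STRING branch reads match[2]: group 1 is the whole token
# (the assumed outer wrapper group), group 2 is the unquoted content that
# becomes the option value.
import re

pattern = re.compile(r'("([^"]*)")')  # outer group ~ FD's wrapping, inner = content

m = pattern.search('goose(model_path="trained_models/gripper_llg.dt")')
print(m.group(1))  # "trained_models/gripper_llg.dt"  (with quotes)
print(m.group(2))  # trained_models/gripper_llg.dt    (stored token value)
```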
diff --git a/downward/src/search/parser/syntax_analyzer.cc b/downward/src/search/parser/syntax_analyzer.cc
index ffcafbfa..62f4fbc3 100644
--- a/downward/src/search/parser/syntax_analyzer.cc
+++ b/downward/src/search/parser/syntax_analyzer.cc
@@ -162,6 +162,7 @@ static unordered_set<TokenType> literal_tokens {
     TokenType::FLOAT,
     TokenType::INTEGER,
     TokenType::BOOLEAN,
+    TokenType::STRING,
     TokenType::IDENTIFIER
 };
@@ -193,7 +194,8 @@ static ASTNodePtr parse_list(TokenStream &tokens, SyntaxAnalyzerContext &context
 static vector<TokenType> PARSE_NODE_TOKEN_TYPES = {
     TokenType::LET, TokenType::IDENTIFIER, TokenType::BOOLEAN,
-    TokenType::INTEGER, TokenType::FLOAT, TokenType::OPENING_BRACKET};
+    TokenType::STRING, TokenType::INTEGER, TokenType::FLOAT,
+    TokenType::OPENING_BRACKET};
 
 static ASTNodePtr parse_node(TokenStream &tokens,
                              SyntaxAnalyzerContext &context) {
@@ -220,6 +222,7 @@ static ASTNodePtr parse_node(TokenStream &tokens,
         return parse_literal(tokens, context);
     }
     case TokenType::BOOLEAN:
+    case TokenType::STRING:
     case TokenType::INTEGER:
     case TokenType::FLOAT:
         return parse_literal(tokens, context);
@@ -244,4 +247,4 @@ ASTNodePtr parse(TokenStream &tokens) {
     }
     return node;
 }
-}
+}
\ No newline at end of file
diff --git a/downward/src/search/parser/token_stream.cc b/downward/src/search/parser/token_stream.cc
index 7879be17..24695feb 100644
--- a/downward/src/search/parser/token_stream.cc
+++ b/downward/src/search/parser/token_stream.cc
@@ -96,12 +96,12 @@ string token_type_name(TokenType token_type) {
         return "Float";
     case TokenType::BOOLEAN:
         return "Boolean";
+    case TokenType::STRING:
+        return "String";
     case TokenType::IDENTIFIER:
         return "Identifier";
     case TokenType::LET:
         return "Let";
-    case TokenType::PATH:
-        return "Path";
     default:
         ABORT("Unknown token type.");
     }
@@ -116,4 +116,4 @@ ostream &operator<<(ostream &out, const Token &token) {
     out << "<Type: '" << token_type_name(token.type) << "', Value: '" << token.content << "'>";
     return out;
 }
-}
+}
\ No newline at end of file
diff --git a/downward/src/search/parser/token_stream.h b/downward/src/search/parser/token_stream.h
index 74420c26..01daaddf 100644
--- a/downward/src/search/parser/token_stream.h
+++ b/downward/src/search/parser/token_stream.h
@@ -19,9 +19,9 @@ enum class TokenType {
     INTEGER,
     FLOAT,
     BOOLEAN,
+    STRING,
     IDENTIFIER,
-    LET,
-    PATH,
+    LET
 };
 
 struct Token {
@@ -59,4 +59,4 @@ struct hash<parser::TokenType> {
     }
 };
 }
-#endif
+#endif
\ No newline at end of file
diff --git a/downward/src/search/plugins/types.cc b/downward/src/search/plugins/types.cc
index 117c139b..d694f834 100644
--- a/downward/src/search/plugins/types.cc
+++ b/downward/src/search/plugins/types.cc
@@ -292,6 +292,7 @@ BasicType TypeRegistry::NO_TYPE = BasicType(typeid(void), "");
 
 TypeRegistry::TypeRegistry() {
     insert_basic_type<bool>();
+    insert_basic_type<std::string>();
     insert_basic_type<int>();
     insert_basic_type<double>();
 }
@@ -345,4 +346,4 @@ const Type &TypeRegistry::get_nonlist_type(type_index type) const {
     }
     return *registered_types.at(type);
 }
-}
+}
\ No newline at end of file
diff --git a/learner/.gitignore b/learner/.gitignore
index 21235fd7..2f289fa4 100644
--- a/learner/.gitignore
+++ b/learner/.gitignore
@@ -13,6 +13,7 @@ saved_models*
 data
 lifted
 plans
+plots
 
 slg
 flg
diff --git a/learner/dataset/generate_all_graphs_gnn.sh b/learner/dataset/generate_all_graphs_gnn.sh
new file mode 100644
index 00000000..dbb4a90f
--- /dev/null
+++ b/learner/dataset/generate_all_graphs_gnn.sh
@@ -0,0 +1,5 @@
+for rep in llg slg dlg glg flg
+do
+  echo "python3 dataset/generate_graphs_gnn.py $rep --regenerate"
+  python3 dataset/generate_graphs_gnn.py $rep --regenerate
+done
diff --git a/learner/dataset/generate_all_graphs_kernel.sh b/learner/dataset/generate_all_graphs_kernel.sh
new file mode 100644
index 00000000..dddbe333
--- /dev/null
+++ b/learner/dataset/generate_all_graphs_kernel.sh
@@ -0,0 +1,5 @@
+for rep in llg slg dlg glg flg
+do
+  echo "python3 dataset/generate_graphs_kernel.py $rep --regenerate"
+  python3 dataset/generate_graphs_kernel.py $rep --regenerate
+done
diff --git a/learner/scripts/generate_graphs_gnn.py b/learner/dataset/generate_graphs_gnn.py
similarity index 93%
rename from learner/scripts/generate_graphs_gnn.py
rename to learner/dataset/generate_graphs_gnn.py
index 5272f6d0..b7e746eb 100644
--- a/learner/scripts/generate_graphs_gnn.py
+++ b/learner/dataset/generate_graphs_gnn.py
@@ -3,7 +3,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
 
 import argparse
 from representation import REPRESENTATIONS
-from dataset.graphs_gnn import gen_graph_rep
+from graphs_gnn import gen_graph_rep
 
 
 if __name__ == "__main__":
diff --git a/learner/scripts/generate_graphs_kernel.py b/learner/dataset/generate_graphs_kernel.py
similarity index 93%
rename from learner/scripts/generate_graphs_kernel.py
rename to learner/dataset/generate_graphs_kernel.py
index e072de96..862c6945 100644
--- a/learner/scripts/generate_graphs_kernel.py
+++ b/learner/dataset/generate_graphs_kernel.py
@@ -3,7 +3,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
 
 import argparse
 from representation import REPRESENTATIONS
-from dataset.graphs_kernel import gen_graph_rep
+from graphs_kernel import gen_graph_rep
 
 
 if __name__ == "__main__":
diff --git a/learner/dataset/graphs_gnn.py b/learner/dataset/graphs_gnn.py
index 711f2636..c019475e 100644
--- a/learner/dataset/graphs_gnn.py
+++ b/learner/dataset/graphs_gnn.py
@@ -65,9 +65,7 @@ def get_graph_data(
 
     print("Loading train data...")
     print("NOTE: the data has been precomputed and saved.")
-    print("Exec")
-    print("\tpython3 scripts/generate_graphs_gnn.py <representation> --regenerate")
-    print("if representation has been updated!")
+    print("Exec 'python3 dataset/generate_graphs_gnn.py <representation> --regenerate' if representation has been updated!")
 
     path = get_data_dir_path(representation=representation)
     print(f"Path to data: {path}")
diff --git a/learner/dataset/graphs_kernel.py b/learner/dataset/graphs_kernel.py
index 364ee64d..41e59d9a 100644
--- a/learner/dataset/graphs_kernel.py
+++ b/learner/dataset/graphs_kernel.py
@@ -56,9 +56,7 @@ def get_graph_data(
 
     print("Loading train data...")
     print("NOTE: the data has been precomputed and saved.")
-    print("Exec")
-    print("\tpython3 scripts/generate_graphs_kernel.py <representation> --regenerate")
-    print("if representation has been updated!")
+    print("Exec 'python3 dataset/generate_graphs_kernel.py <representation> --regenerate' if representation has been updated!")
 
     path = get_data_dir_path(representation=representation)
     print(f"Path to data: {path}")
@@ -129,9 +127,9 @@ def gen_graph_rep(
 ) -> None:
     """ Generate graph representations from saved optimal plans. """
 
-    tasks = get_ipc_domain_problem_files(del_free=False)
+    # tasks = get_ipc_domain_problem_files(del_free=False)
     # tasks += get_all_htg_instance_files(split=True)
-    tasks += get_train_goose_instance_files()
+    tasks = get_train_goose_instance_files()
 
     new_generated = 0
     pbar = tqdm(tasks)
diff --git a/learner/representation/dlg.py b/learner/representation/dlg.py
index c69a56c7..e85e1b60 100644
--- a/learner/representation/dlg.py
+++ b/learner/representation/dlg.py
@@ -38,8 +38,10 @@ def _compute_graph_representation(self) -> None:
       # these features may get updated in state encoding
       if proposition in positive_goals:
         x_p = self._one_hot_node(DLG_FEATURES.POSITIVE_GOAL.value)
+        self._pos_goal_nodes.add(node_p)
       elif proposition in negative_goals:
         x_p = self._one_hot_node(DLG_FEATURES.NEGATIVE_GOAL.value)
+        self._neg_goal_nodes.add(node_p)
       else:
         x_p = self._zero_node()
       G.add_node(node_p, x=x_p)
diff --git a/learner/representation/glg.py b/learner/representation/glg.py
index 535d5011..b8dfc981 100644
--- a/learner/representation/glg.py
+++ b/learner/representation/glg.py
@@ -41,9 +41,11 @@ def _compute_graph_representation(self) -> None:
       node_p = self._proposition_to_str(proposition)
       # these features may get updated in state encoding
       if proposition in positive_goals:
-        x_p=self._one_hot_node(GLG_FEATURES.POSITIVE_GOAL.value)
+        x_p = self._one_hot_node(GLG_FEATURES.POSITIVE_GOAL.value)
+        self._pos_goal_nodes.add(node_p)
       elif proposition in negative_goals:
-        x_p=self._one_hot_node(GLG_FEATURES.NEGATIVE_GOAL.value)
+        x_p = self._one_hot_node(GLG_FEATURES.NEGATIVE_GOAL.value)
+        self._neg_goal_nodes.add(node_p)
       else:
         x_p=self._zero_node()
       G.add_node(node_p, x=x_p)
diff --git a/learner/representation/slg.py b/learner/representation/slg.py
index 75ef7554..2159fa12 100644
--- a/learner/representation/slg.py
+++ b/learner/representation/slg.py
@@ -89,8 +89,10 @@ def _compute_graph_representation(self) -> None:
       # these features may get updated in state encoding
       if proposition in positive_goals:
         x_p = self._one_hot_node(SLG_FEATURES.POSITIVE_GOAL.value)
+        self._pos_goal_nodes.add(node_p)
       elif proposition in negative_goals:
        x_p = self._one_hot_node(SLG_FEATURES.NEGATIVE_GOAL.value)
+        self._neg_goal_nodes.add(node_p)
       else:
         x_p = self._zero_node()
       G.add_node(node_p, x=x_p)
@@ -135,3 +137,18 @@ def state_to_tensor(self, state: State) -> Tuple[Tensor, Tensor]:
       x[self._node_to_i[p]][SLG_FEATURES.STATE.value] = 1
 
     return x, self.edge_indices
+
+  def state_to_cgraph(self, state: State) -> CGraph:
+    """ States are represented as a list of (pred, [args]) """
+    c_graph = self.c_graph.copy()
+
+    for p in state:
+
+      # activated proposition overlaps with a goal Atom or NegatedAtom
+      if p in self._pos_goal_nodes:
+        c_graph.nodes[p]['colour'] = c_graph.nodes[p]['colour'] + ACTIVATED_POS_GOAL_COLOUR_SUFFIX
+      elif p in self._neg_goal_nodes:
+        c_graph.nodes[p]['colour'] = c_graph.nodes[p]['colour'] + ACTIVATED_NEG_GOAL_COLOUR_SUFFIX
+
+    return c_graph
+
\ No newline at end of file
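The new `state_to_cgraph` marks goal propositions that hold in the evaluated state by appending a suffix to their node colour, so the WL kernel can tell achieved goals from unachieved ones. A self-contained sketch of that mechanism; the colour values, node names, and suffix constants are invented for illustration (the real `ACTIVATED_*_COLOUR_SUFFIX` constants live elsewhere in the package):
```
# Toy version of SLG.state_to_cgraph: copy the coloured graph and re-colour
# goal nodes satisfied in the current state. All values here are hypothetical.
import networkx as nx

ACTIVATED_POS_GOAL_COLOUR_SUFFIX = "_apg"   # assumed values
ACTIVATED_NEG_GOAL_COLOUR_SUFFIX = "_ang"

base = nx.Graph()
base.add_node("on(a,b)", colour="pos_goal")
base.add_node("clear(c)", colour="no_goal")
pos_goal_nodes = {"on(a,b)"}
neg_goal_nodes = set()

def state_to_cgraph(state):
    c_graph = base.copy()
    for p in state:
        if p in pos_goal_nodes:
            c_graph.nodes[p]["colour"] += ACTIVATED_POS_GOAL_COLOUR_SUFFIX
        elif p in neg_goal_nodes:
            c_graph.nodes[p]["colour"] += ACTIVATED_NEG_GOAL_COLOUR_SUFFIX
    return c_graph

print(state_to_cgraph(["on(a,b)"]).nodes["on(a,b)"]["colour"])  # pos_goal_apg
```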
diff --git a/learner/run_gnn.py b/learner/run_gnn.py
index 31330494..8882002d 100644
--- a/learner/run_gnn.py
+++ b/learner/run_gnn.py
@@ -32,4 +32,5 @@
         seed=0,
     )
 
+    print(cmd)
     os.system(cmd)
diff --git a/learner/scripts/generate_all_graphs.sh b/learner/scripts/generate_all_graphs.sh
deleted file mode 100644
index 1ad626ef..00000000
--- a/learner/scripts/generate_all_graphs.sh
+++ /dev/null
@@ -1,5 +0,0 @@
-for rep in ldg-el fdg-el sdg-el gdg-el
-do
-  echo "python3 scripts/generate_graphs.py $rep --regenerate"
-  python3 scripts/generate_graphs.py $rep --regenerate
-done
diff --git a/learner/scripts/.gitignore b/learner/scripts_gnn/.gitignore
similarity index 100%
rename from learner/scripts/.gitignore
rename to learner/scripts_gnn/.gitignore
diff --git a/learner/scripts/cluster1_job_3090 b/learner/scripts_gnn/cluster1_job_3090
similarity index 100%
rename from learner/scripts/cluster1_job_3090
rename to learner/scripts_gnn/cluster1_job_3090
diff --git a/learner/scripts/cluster1_job_a6000 b/learner/scripts_gnn/cluster1_job_a6000
similarity index 100%
rename from learner/scripts/cluster1_job_a6000
rename to learner/scripts_gnn/cluster1_job_a6000
diff --git a/learner/scripts/cluster1_job_any b/learner/scripts_gnn/cluster1_job_any
similarity index 100%
rename from learner/scripts/cluster1_job_any
rename to learner/scripts_gnn/cluster1_job_any
diff --git a/learner/scripts/cluster1_job_planopt b/learner/scripts_gnn/cluster1_job_planopt
similarity index 100%
rename from learner/scripts/cluster1_job_planopt
rename to learner/scripts_gnn/cluster1_job_planopt
diff --git a/learner/scripts/collect_cluster1_logs.sh b/learner/scripts_gnn/collect_cluster1_logs.sh
similarity index 100%
rename from learner/scripts/collect_cluster1_logs.sh
rename to learner/scripts_gnn/collect_cluster1_logs.sh
diff --git a/learner/scripts/predict_dd_and_di.py b/learner/scripts_gnn/predict_dd_and_di.py
similarity index 100%
rename from learner/scripts/predict_dd_and_di.py
rename to learner/scripts_gnn/predict_dd_and_di.py
diff --git a/learner/scripts/submit_dd_train_only.sh b/learner/scripts_gnn/submit_dd_train_only.sh
similarity index 100%
rename from learner/scripts/submit_dd_train_only.sh
rename to learner/scripts_gnn/submit_dd_train_only.sh
diff --git a/learner/scripts/submit_dd_train_validate_test.sh b/learner/scripts_gnn/submit_dd_train_validate_test.sh
similarity index 100%
rename from learner/scripts/submit_dd_train_validate_test.sh
rename to learner/scripts_gnn/submit_dd_train_validate_test.sh
diff --git a/learner/scripts/submit_di_train_only.sh b/learner/scripts_gnn/submit_di_train_only.sh
similarity index 100%
rename from learner/scripts/submit_di_train_only.sh
rename to learner/scripts_gnn/submit_di_train_only.sh
diff --git a/learner/scripts/submit_di_train_validate_test.sh b/learner/scripts_gnn/submit_di_train_validate_test.sh
similarity index 100%
rename from learner/scripts/submit_di_train_validate_test.sh
rename to learner/scripts_gnn/submit_di_train_validate_test.sh
diff --git a/learner/scripts/submit_predict.sh b/learner/scripts_gnn/submit_predict.sh
similarity index 100%
rename from learner/scripts/submit_predict.sh
rename to learner/scripts_gnn/submit_predict.sh
diff --git a/learner/scripts/train_validate_test_dd.py b/learner/scripts_gnn/train_validate_test_dd.py
similarity index 100%
rename from learner/scripts/train_validate_test_dd.py
rename to learner/scripts_gnn/train_validate_test_dd.py
diff --git a/learner/scripts/train_validate_test_di.py b/learner/scripts_gnn/train_validate_test_di.py
similarity index 100%
rename from learner/scripts/train_validate_test_di.py
rename to learner/scripts_gnn/train_validate_test_di.py
diff --git a/learner/scripts_kernel/cross_validate_all.sh b/learner/scripts_kernel/cross_validate_all.sh
new file mode 100644
index 00000000..a628aa2c
--- /dev/null
+++ b/learner/scripts_kernel/cross_validate_all.sh
@@ -0,0 +1,18 @@
+LOG_DIR=logs/train_kernel
+
+mkdir -p $LOG_DIR
+
+for l in 0 1 2 3 4
+do
+  for k in wl
+  do
+    for r in llg slg dlg glg
+    do
+      for d in gripper spanner visitall visitsome blocks ferry sokoban n-puzzle
+      do
+        echo $r $k $l $d
+        python3 train_kernel.py -k $k -l $l -r $r -d $d --visualise --cross-validation > $LOG_DIR/${r}_${d}_${k}_${l}.log
+      done
+    done
+  done
+done
\ No newline at end of file
diff --git a/learner/scripts_kernel/train_all.sh b/learner/scripts_kernel/train_all.sh
new file mode 100644
index 00000000..3cade11c
--- /dev/null
+++ b/learner/scripts_kernel/train_all.sh
@@ -0,0 +1,18 @@
+LOG_DIR=logs/train_kernel
+
+mkdir -p $LOG_DIR
+
+for l in 0 1 2 3 4
+do
+  for k in wl
+  do
+    for r in llg slg dlg glg
+    do
+      for d in gripper spanner visitall visitsome blocks ferry sokoban n-puzzle
+      do
+        echo $r $k $l $d
+        python3 train_kernel.py -k $k -l $l -r $r -d $d --save-file ${r}_${d}_${k}_${l} > $LOG_DIR/${r}_${d}_${k}_${l}.log
+      done
+    done
+  done
+done
\ No newline at end of file
diff --git a/learner/test_gnn.sh b/learner/test_gnn.sh
index c773d86d..a0a53ffb 100644
--- a/learner/test_gnn.sh
+++ b/learner/test_gnn.sh
@@ -1 +1,2 @@
-singularity exec --nv ../gpu.sif python3 run_gnn.py ../benchmarks/goose/gripper/domain.pddl ../benchmarks/goose/gripper/test/gripper-n20.pddl -m saved_models/dd_llg_gripper.dt -r llg
\ No newline at end of file
+# singularity exec --nv ../gpu.sif python3 run_gnn.py ../benchmarks/goose/gripper/domain.pddl ../benchmarks/goose/gripper/test/gripper-n20.pddl -m saved_models/dd_llg_gripper.dt -r llg
+singularity exec --nv ../gpu.sif python3 run_gnn.py ../benchmarks/goose/gripper/domain.pddl ../benchmarks/goose/gripper/test/gripper-n20.pddl -m saved_models/dd_slg_gripper.dt -r slg
\ No newline at end of file
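The two `scripts_kernel` scripts above share one naming scheme with `save_sklearn_model` (added in `util/save_load.py` later in this diff). A small sketch of the artifacts produced for a single configuration, with illustrative parameter values:
```
# Expected artifacts from scripts_kernel/train_all.sh for one configuration
# (representation=llg, domain=gripper, kernel=wl, iterations=2); purely
# illustrative path arithmetic mirroring the naming in the scripts above.
rep, dom, ker, it = "llg", "gripper", "wl", 2
save_file = f"{rep}_{dom}_{ker}_{it}"
print(f"logs/train_kernel/{save_file}.log")    # captured stdout of the run
print(f"trained_models/{save_file}.joblib")    # model, via save_sklearn_model
```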
""" +import os import time import argparse import representation import kernels import numpy as np from dataset.dataset import get_dataset_from_args_kernels -from util.save_load import print_arguments +from util.save_load import print_arguments, save_sklearn_model from util.metrics import f1_macro +from util.visualise import get_confusion_matrix from sklearn.svm import LinearSVR, SVR from sklearn.model_selection import cross_validate from sklearn.metrics import make_scorer, mean_squared_error +import warnings +warnings.filterwarnings('ignore') + _MODELS = [ "linear-svr", @@ -20,13 +25,18 @@ _CV_FOLDS = 5 _MAX_MODEL_ITER = 10000 +_PLOT_DIR = "plots" +_SCORING = { + "mse": make_scorer(mean_squared_error), + "f1_macro": make_scorer(f1_macro) +} + -def create_parser(): +def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('-r', '--rep', type=str, required=True, choices=representation.REPRESENTATIONS, help="graph representation to use") - # TODO implement CGraph for SLG parser.add_argument('-k', '--kernel', type=str, required=True, choices=kernels.KERNELS, help="graph representation to use") @@ -48,16 +58,88 @@ def create_parser(): parser.add_argument('-s', '--seed', type=int, default=0, help="random seed") + parser.add_argument('--cross-validation', action='store_true', + help="performs cross validation scoring; otherwise train on whole dataset") + parser.add_argument('--save-file', type=str, default=None, + help="save file of model weights when not using --cross-validation") + parser.add_argument('--visualise', action='store_true', + help="visualise train and test predictions; only used with --cross-validation") parser.add_argument('--small-train', action="store_true", help="use small train set, useful for debugging") - parser.add_argument('--save-file', dest="save_file", type=str, default=None, - help="file to save model weights") - return parser + return parser.parse_args() + +def perform_training(X, y, model, args): + print(f"Training on entire {args.domain} for {model_name}...") + t = time.time() + model.fit(X, y) + print(f"Model training completed in {time.time()-t:.2f}s") + for metric in _SCORING: + print(f"train_{metric}: {_SCORING[metric](model, X, y):.2f}") + save_sklearn_model(model, args) + return + +def perform_cross_validation(X, y, model, args): + print(f"Performing {_CV_FOLDS}-fold cross validation on {model_name}...") + t = time.time() + scores = cross_validate( + model, X, y, + cv=_CV_FOLDS, scoring=_SCORING, return_train_score=True, n_jobs=-1, + return_estimator=args.visualise, return_indices=args.visualise, + ) + print(f"CV completed in {time.time() - t:.2f}s") + + for metric in _SCORING: + train_key = f"train_{metric}" + test_key = f"test_{metric}" + print(f"train_{metric}: {scores[train_key].mean():.2f} ± {scores[train_key].std():.2f}") + print(f"test_{metric}: {scores[test_key].mean():.2f} ± {scores[test_key].std():.2f}") + + if args.visualise: + """ Visualise predictions and save to file + Performs some redundant computations + """ + + if model_name == "svr": # kernel matrix case + raise NotImplementedError + + print("Saving visualisation...") + train_trues = [] + train_preds = [] + test_trues = [] + test_preds = [] + + for i in range(_CV_FOLDS): + estimator = scores["estimator"][i] + train_indices = scores["indices"]["train"][i] + test_indices = scores["indices"]["test"][i] + X_train = X[train_indices] + X_test = X[test_indices] + y_train = y[train_indices] + y_test = y[test_indices] + train_pred = estimator.predict(X_train) + test_pred = 
diff --git a/learner/util/save_load.py b/learner/util/save_load.py
index 255d8152..0ac282f6 100644
--- a/learner/util/save_load.py
+++ b/learner/util/save_load.py
@@ -1,5 +1,7 @@
+""" Module for dealing with model saving and loading. """
 import os
 import torch
+import joblib
 import datetime
 import representation
 from argparse import Namespace as Args
@@ -7,105 +9,125 @@ from gnns.base_gnn import BasePredictor as GNN
 from gnns import *
 
-""" Module for dealing with model saving and loading. """
+
+_TRAINED_MODELS_SAVE_DIR = "trained_models"
+os.makedirs(_TRAINED_MODELS_SAVE_DIR, exist_ok=True)
 
 
 def arg_to_params(args, in_feat=4, out_feat=1):
-  model_name = args.model
-  nlayers = args.nlayers
-  nhid = args.nhid
-  in_feat = args.in_feat
-  n_edge_labels = args.n_edge_labels
-  share_layers = args.share_layers
-  task = args.task
-  pool = args.pool
-  aggr = args.aggr
-  vn = args.vn
-  rep = args.rep
-  model_params = {
-    'model_name': model_name,
-    'in_feat': in_feat,
-    'out_feat': out_feat,
-    'nlayers': nlayers,
-    'share_layers': share_layers,
-    'n_edge_labels': n_edge_labels,
-    'nhid': nhid,
-    'aggr': aggr,
-    'pool': pool,
-    'task': task,
-    'rep': rep,
-    'vn': vn,
-  }
-  return model_params
+    model_name = args.model
+    nlayers = args.nlayers
+    nhid = args.nhid
+    in_feat = args.in_feat
+    n_edge_labels = args.n_edge_labels
+    share_layers = args.share_layers
+    task = args.task
+    pool = args.pool
+    aggr = args.aggr
+    vn = args.vn
+    rep = args.rep
+    model_params = {
+        'model_name': model_name,
+        'in_feat': in_feat,
+        'out_feat': out_feat,
+        'nlayers': nlayers,
+        'share_layers': share_layers,
+        'n_edge_labels': n_edge_labels,
+        'nhid': nhid,
+        'aggr': aggr,
+        'pool': pool,
+        'task': task,
+        'rep': rep,
+        'vn': vn,
+    }
+    return model_params
 
 
 def print_arguments(args, ignore_params=set()):
-  if hasattr(args, 'pretrained') and args.pretrained is not None:
-    return
-  print("Parsed arguments:")
-  for k, v in vars(args).items():
-    if k in ignore_params.union({"device", "optimal", "save_model", "save_file", "no_tqdm", "tqdm", "fast_train"}):
-      continue
-    print('{0:20} {1}'.format(k, v))
+    if hasattr(args, 'pretrained') and args.pretrained is not None:
+        return
+    print("Parsed arguments:")
+    for k, v in vars(args).items():
+        if k in ignore_params.union({"device", "optimal", "save_model", "save_file", "no_tqdm", "tqdm", "fast_train"}):
+            continue
+        print('{0:20} {1}'.format(k, v))
 
 
 def save_model_from_dict(model_dict, args):
-  if not hasattr(args, "save_file") or args.save_file is None:
-    return
-  print("Saving model...")
-  save_dir = 'trained_models'
-  os.makedirs(f"{save_dir}/", exist_ok=True)
-  model_file_name = args.save_file.replace(".dt", "")
-  path = f'{save_dir}/{model_file_name}.dt'
-  torch.save((model_dict, args), path)
-  print("Model saved!")
-  print("Model parameter file:")
-  print(model_file_name)
+    if not hasattr(args, "save_file") or args.save_file is None: return
+    print("Saving model...")
+    model_file_name = args.save_file.replace(".dt", "")
+    path = f'{_TRAINED_MODELS_SAVE_DIR}/{model_file_name}.dt'
+    torch.save((model_dict, args), path)
+    print("Model saved!")
+    print("Model parameter file:")
+    print(model_file_name)
+    return
 
 
 def save_model(model, args):
-  save_model_from_dict(model.model.state_dict(), args)
+    save_model_from_dict(model.model.state_dict(), args)
+    return
+
+
+def save_sklearn_model(model, args):
+    if not hasattr(args, "save_file") or args.save_file is None: return
+    print("Saving model...")
+    model_file_name = args.save_file.replace(".joblib", "")
+    path = f'{_TRAINED_MODELS_SAVE_DIR}/{model_file_name}.joblib'
+    joblib.dump((model, args), path)
+    print("Model saved!")
+    print("Model parameter file:")
+    print(model_file_name)
+    return
+
+
+def load_sklearn_model(path, ignore_subdir=False):
+    if not ignore_subdir and _TRAINED_MODELS_SAVE_DIR not in path:
+        path = _TRAINED_MODELS_SAVE_DIR + "/" + path
+    model, args = joblib.load(path)
+    return model, args
 
 
 def load_model(path, print_args=False, jit=False, ignore_subdir=False) -> Tuple[GNN, Args]:
-  print("Loading model...")
-  assert ".pt" not in path, f"Found .pt in path {path}"
-  if ".dt" not in path:
-    path = path+".dt"
-  if not ignore_subdir and "trained_models" not in path:
-    path = "trained_models/" + path
-  try:
-    if torch.cuda.is_available():
-      model_state_dict, args = torch.load(path)
-    else:
-      model_state_dict, args = torch.load(path, map_location=torch.device('cpu'))
-  except:
-    print(f"Model not found at {path}")
-    exit(-1)
-  # update legacy naming
-  if "dg-el" in args.rep:
-    args.rep = args.rep.replace("dg-el", "lg")
-  model = GNNS[args.model](params=arg_to_params(args), jit=jit)
-  model.load_state_dict_into_gnn(model_state_dict)
-  print("Model loaded!")
-  if print_args:
-    print_arguments(args)
-  model.eval()
-  return model, args
+    print("Loading model...")
+    assert ".pt" not in path, f"Found .pt in path {path}"
+    if ".dt" not in path:
+        path = path+".dt"
+    if not ignore_subdir and _TRAINED_MODELS_SAVE_DIR not in path:
+        path = _TRAINED_MODELS_SAVE_DIR + "/" + path
+    try:
+        if torch.cuda.is_available():
+            model_state_dict, args = torch.load(path)
+        else:
+            model_state_dict, args = torch.load(path, map_location=torch.device('cpu'))
+    except:
+        print(f"Model not found at {path}")
+        exit(-1)
+    # update legacy naming
+    if "dg-el" in args.rep:
+        args.rep = args.rep.replace("dg-el", "lg")
+    model = GNNS[args.model](params=arg_to_params(args), jit=jit)
+    model.load_state_dict_into_gnn(model_state_dict)
+    print("Model loaded!")
+    if print_args:
+        print_arguments(args)
+    model.eval()
+    return model, args
 
 
 def load_model_and_setup_gnn(path, domain_file, problem_file):
-  model, args = load_model(path, ignore_subdir=True)
-  device = "cuda" if torch.cuda.is_available() else "cpu"
-  model = model.to(device)
-  model.batch_search(True)
-  model.update_representation(domain_pddl=domain_file,
-                              problem_pddl=problem_file,
-                              args=args,
-                              device=device)
-  model.set_zero_grad()
-  model.eval()
-  return model
-
+    model, args = load_model(path, ignore_subdir=True)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = model.to(device)
+    model.batch_search(True)
+    model.update_representation(domain_pddl=domain_file,
+                                problem_pddl=problem_file,
+                                args=args,
+                                device=device)
+    model.set_zero_grad()
+    model.eval()
+    return model
+
\ No newline at end of file
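A roundtrip sketch for the new sklearn helpers, assuming it runs from inside `learner/`; the save-file name is hypothetical. `save_sklearn_model` strips any `.joblib` suffix and writes `(model, args)` under `trained_models/`, while `load_sklearn_model` expects the full file name and prepends the same directory:
```
# Save/load roundtrip for the new joblib-backed helpers; the file name and
# args namespace are illustrative, not produced by any real training run.
from argparse import Namespace
from sklearn.svm import LinearSVR
from util.save_load import save_sklearn_model, load_sklearn_model

args = Namespace(save_file="llg_gripper_wl_2")   # hypothetical save file name
model = LinearSVR(dual="auto", max_iter=10000)
# ... model.fit(X, y) on real kernel features ...
save_sklearn_model(model, args)                  # -> trained_models/llg_gripper_wl_2.joblib

model2, args2 = load_sklearn_model("llg_gripper_wl_2.joblib")
```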
""" + +_TRAINED_MODELS_SAVE_DIR = "trained_models" +os.makedirs(_TRAINED_MODELS_SAVE_DIR, exist_ok=True) def arg_to_params(args, in_feat=4, out_feat=1): - model_name = args.model - nlayers = args.nlayers - nhid = args.nhid - in_feat = args.in_feat - n_edge_labels = args.n_edge_labels - share_layers = args.share_layers - task = args.task - pool = args.pool - aggr = args.aggr - vn = args.vn - rep = args.rep - model_params = { - 'model_name': model_name, - 'in_feat': in_feat, - 'out_feat': out_feat, - 'nlayers': nlayers, - 'share_layers': share_layers, - 'n_edge_labels': n_edge_labels, - 'nhid': nhid, - 'aggr': aggr, - 'pool': pool, - 'task': task, - 'rep': rep, - 'vn': vn, - } - return model_params + model_name = args.model + nlayers = args.nlayers + nhid = args.nhid + in_feat = args.in_feat + n_edge_labels = args.n_edge_labels + share_layers = args.share_layers + task = args.task + pool = args.pool + aggr = args.aggr + vn = args.vn + rep = args.rep + model_params = { + 'model_name': model_name, + 'in_feat': in_feat, + 'out_feat': out_feat, + 'nlayers': nlayers, + 'share_layers': share_layers, + 'n_edge_labels': n_edge_labels, + 'nhid': nhid, + 'aggr': aggr, + 'pool': pool, + 'task': task, + 'rep': rep, + 'vn': vn, + } + return model_params def print_arguments(args, ignore_params=set()): - if hasattr(args, 'pretrained') and args.pretrained is not None: - return - print("Parsed arguments:") - for k, v in vars(args).items(): - if k in ignore_params.union({"device", "optimal", "save_model", "save_file", "no_tqdm", "tqdm", "fast_train"}): - continue - print('{0:20} {1}'.format(k, v)) + if hasattr(args, 'pretrained') and args.pretrained is not None: + return + print("Parsed arguments:") + for k, v in vars(args).items(): + if k in ignore_params.union({"device", "optimal", "save_model", "save_file", "no_tqdm", "tqdm", "fast_train"}): + continue + print('{0:20} {1}'.format(k, v)) def save_model_from_dict(model_dict, args): - if not hasattr(args, "save_file") or args.save_file is None: - return - print("Saving model...") - save_dir = 'trained_models' - os.makedirs(f"{save_dir}/", exist_ok=True) - model_file_name = args.save_file.replace(".dt", "") - path = f'{save_dir}/{model_file_name}.dt' - torch.save((model_dict, args), path) - print("Model saved!") - print("Model parameter file:") - print(model_file_name) + if not hasattr(args, "save_file") or args.save_file is None: return + print("Saving model...") + model_file_name = args.save_file.replace(".dt", "") + path = f'{_TRAINED_MODELS_SAVE_DIR}/{model_file_name}.dt' + torch.save((model_dict, args), path) + print("Model saved!") + print("Model parameter file:") + print(model_file_name) + return def save_model(model, args): - save_model_from_dict(model.model.state_dict(), args) + save_model_from_dict(model.model.state_dict(), args) + return + + +def save_sklearn_model(model, args): + if not hasattr(args, "save_file") or args.save_file is None: return + print("Saving model...") + model_file_name = args.save_file.replace(".joblib", "") + path = f'{_TRAINED_MODELS_SAVE_DIR}/{model_file_name}.joblib' + joblib.dump((model, args), path) + print("Model saved!") + print("Model parameter file:") + print(model_file_name) + return + + +def load_sklearn_model(path, ignore_subdir=False): + if not ignore_subdir and _TRAINED_MODELS_SAVE_DIR not in path: + path = _TRAINED_MODELS_SAVE_DIR + "/" + path + model, args = joblib.load(path) + return model, args def load_model(path, print_args=False, jit=False, ignore_subdir=False) -> Tuple[GNN, Args]: - print("Loading 
model...") - assert ".pt" not in path, f"Found .pt in path {path}" - if ".dt" not in path: - path = path+".dt" - if not ignore_subdir and "trained_models" not in path: - path = "trained_models/" + path - try: - if torch.cuda.is_available(): - model_state_dict, args = torch.load(path) - else: - model_state_dict, args = torch.load(path, map_location=torch.device('cpu')) - except: - print(f"Model not found at {path}") - exit(-1) - # update legacy naming - if "dg-el" in args.rep: - args.rep = args.rep.replace("dg-el", "lg") - model = GNNS[args.model](params=arg_to_params(args), jit=jit) - model.load_state_dict_into_gnn(model_state_dict) - print("Model loaded!") - if print_args: - print_arguments(args) - model.eval() - return model, args + print("Loading model...") + assert ".pt" not in path, f"Found .pt in path {path}" + if ".dt" not in path: + path = path+".dt" + if not ignore_subdir and _TRAINED_MODELS_SAVE_DIR not in path: + path = _TRAINED_MODELS_SAVE_DIR + "/" + path + try: + if torch.cuda.is_available(): + model_state_dict, args = torch.load(path) + else: + model_state_dict, args = torch.load(path, map_location=torch.device('cpu')) + except: + print(f"Model not found at {path}") + exit(-1) + # update legacy naming + if "dg-el" in args.rep: + args.rep = args.rep.replace("dg-el", "lg") + model = GNNS[args.model](params=arg_to_params(args), jit=jit) + model.load_state_dict_into_gnn(model_state_dict) + print("Model loaded!") + if print_args: + print_arguments(args) + model.eval() + return model, args def load_model_and_setup_gnn(path, domain_file, problem_file): - model, args = load_model(path, ignore_subdir=True) - device = "cuda" if torch.cuda.is_available() else "cpu" - model = model.to(device) - model.batch_search(True) - model.update_representation(domain_pddl=domain_file, - problem_pddl=problem_file, - args=args, - device=device) - model.set_zero_grad() - model.eval() - return model - + model, args = load_model(path, ignore_subdir=True) + device = "cuda" if torch.cuda.is_available() else "cpu" + model = model.to(device) + model.batch_search(True) + model.update_representation(domain_pddl=domain_file, + problem_pddl=problem_file, + args=args, + device=device) + model.set_zero_grad() + model.eval() + return model + \ No newline at end of file diff --git a/learner/util/search.py b/learner/util/search.py index 118f5324..89582f01 100644 --- a/learner/util/search.py +++ b/learner/util/search.py @@ -64,25 +64,10 @@ def fd_cmd(rep, df, pf, m, search, seed, timeout=TIMEOUT): else: raise NotImplementedError - # A hack given that FD FeaturePlugin cannot parse strings - # 0: slg, 1: flg, 2: dlg, 3: llg - assert rep in REPRESENTATIONS - config_file = rep - config = { - "slg":0, - "flg":1, - "dlg":2, - "llg":3, - }[rep] - description = f"fd_{pf.replace('.pddl','').replace('/','-')}_{search}_{os.path.basename(m).replace('.dt', '')}" sas_file = f"sas_files/{description}.sas_file" plan_file = f"plans/{description}.plan" - with open(config_file, 'w') as f: - f.write(m+'\n') - f.write(df+'\n') - f.write(pf+'\n') - f.close() - cmd = f'./../downward/fast-downward.py --search-time-limit {timeout} --sas-file {sas_file} --plan-file {plan_file} {df} {pf} --search "{search}([goose(graph={config})])"' + cmd = f"./../downward/fast-downward.py --search-time-limit {timeout} --sas-file {sas_file} --plan-file {plan_file} "+\ + f"{df} {pf} --search '{search}([goose(model_path=\"{m}\", domain_file=\"{df}\", instance_file=\"{pf}\")])'" cmd = f"export GOOSE={os.getcwd()} && {cmd}" return cmd, sas_file diff --git 
-
-
-@torch.no_grad()
-def visualise_train_stats(model, device, train_loader, val_loader=None, test_loader=None, max_cost=20, print_stats=True,
-                          classify=False, view_cm=False, cm_train="cm_train", cm_val="cm_val", cm_test="cm_test"):
-    model = model.to(device)
-    model.eval()
-
-    def get_stats_from_loader(loader):
-        preds = []
-        true = []
-        errors = [[] for _ in range(max_cost+1)]
-        for batch in tqdm(loader):
-            batch = batch.to(device)
-            y = batch.y
-            out = model.forward(batch)
-            if classify:
-                out = torch.argmax(out, dim=1)
-            else:
-                out = torch.maximum(out, torch.zeros_like(out))  # so h is nonzero
-            batch_errors = (y - out) / y
-            for i in range(len(y)):
-                e = batch_errors[i].detach().cpu().item()
-                c = y[i].detach().cpu().item()
-                o = out[i].detach().cpu().item()
-                preds.append(round(o))
-                true.append(c)
-                errors[0].append(e)
-                if c <= max_cost:
-                    errors[round(c)].append(c - o)
-        errors[0] = np.array(errors[0])
-        errors[0][np.isnan(errors[0])] = 0
-        preds = np.array(preds)
-        true = np.array(true)
-        return preds, true, errors
-
-    print("Collecting stats...")
-
-    # print("Prediction value set", np.unique(train_preds, return_counts=True))
-    os.makedirs("plots", exist_ok=True)
-    for fname in ["error_prop", "preds_train", "error_train", "preds_val", "error_val", "preds_test", "error_test"]:
-        try:
-            os.remove(f"plots/{fname}.png")
-        except:
-            pass
-
-    boxes = []
-    ticks = []
-
-    if train_loader is not None:
-        train_preds, train_true, train_errors = get_stats_from_loader(train_loader)
-        view_confusion_matrix(plt_title=cm_train, y_true=train_true, y_pred=train_preds, view_cm=view_cm)
-        # boxes.append(train_errors[0])
-        # ticks.append((len(boxes), 'train'))
-        # plt.hist(train_preds, bins=round(np.max(train_preds) + 1),
-        #          range=(0, round(np.max(train_preds) + 1)))
-        # plt.title('Train prediction distribution')
-        # plt.savefig('plots/preds_train', dpi=480)
-        # plt.clf()
-        #
-        # plt.boxplot([train_errors[i] for i in range(1, max_cost + 1)])
-        # plt.title('Train error differences over states away from target')
-        # plt.ylim((-4, 4))
-        # plt.tight_layout()
-        # plt.savefig('plots/error_train', dpi=480)
-        # plt.clf()
-    if val_loader is not None:
-        val_preds, val_true, val_errors = get_stats_from_loader(val_loader)
-        view_confusion_matrix(plt_title=cm_val, y_true=val_true, y_pred=val_preds, view_cm=view_cm)
-        # boxes.append(val_errors[0])
-        # ticks.append((len(boxes), 'val'))
-        # plt.hist(val_preds, bins=round(np.max(val_preds) + 1),
-        #          range=(0, round(np.max(val_preds) + 1)))
-        # plt.title('Validation prediction distribution')
-        # plt.savefig('plots/preds_val', dpi=480)
-        # plt.clf()
-        #
-        # plt.boxplot([val_errors[i] for i in range(1, max_cost + 1)])
-        # plt.title('Val error differences over states away from target')
-        # plt.ylim((-4, 4))
-        # plt.tight_layout()
-        # plt.savefig('plots/error_val', dpi=480)
-        # plt.clf()
-    if test_loader is not None:
-        test_preds, test_true, test_errors = get_stats_from_loader(test_loader)
-        view_confusion_matrix(plt_title=cm_test, y_true=test_true, y_pred=test_preds, view_cm=view_cm)
-        # boxes.append(test_errors[0])
-        # ticks.append((len(boxes), 'test'))
-        # plt.hist(test_preds, bins=round(np.max(test_preds) + 1),
-        #          range=(0, round(np.max(test_preds) + 1)))
-        # plt.title('Test prediction distribution')
-        # plt.savefig('plots/preds_val', dpi=480)
-        # plt.clf()
-        #
-        # plt.boxplot([test_errors[i] for i in range(1, max_cost + 1)])
-        # plt.title('Test error differences over states away from target')
-        # plt.ylim((-4, 4))
-        # plt.tight_layout()
-        # plt.savefig('plots/error_test', dpi=480)
-        # plt.clf()
-
-    print("Plotting done!")
-
-    # Statistics
-    if print_stats:
-        print("{0:<20} {1:>10} {2:>10} {3:>10} {4:>10} {5:>10}".format(" ", "Q1", "median", "Q3", "min", "max"))
-        if train_loader is not None:
-            print_quartiles("train prop_err:", train_errors[0], floats=True)
-        if val_loader is not None:
-            print_quartiles("val prop_err:", val_errors[0], floats=True)
-        if test_loader is not None:
-            print_quartiles("test prop_err:", test_errors[0], floats=True)
-        print("% admissible")
-        if train_loader is not None:
-            print(f"train: {np.count_nonzero(train_errors[0] > 0) / len(train_errors[0]):.2f}")
-        if val_loader is not None:
-            print(f"val: {np.count_nonzero(val_errors[0] > 0) / len(val_errors[0]):.2f}")
-        if test_loader is not None:
-            print(f"test: {np.count_nonzero(test_errors[0] > 0) / len(test_errors[0])}:.2f")
-
-    # plt.boxplot(boxes)
-    # plt.xticks(ticks)
-    # plt.ylim((-1, 1))
-    # plt.title('Proportion errors')
-    # plt.tight_layout()
-    # plt.savefig('plots/error_prop', dpi=480)
-    # plt.clf()
-
-    return
diff --git a/learner/util/visualise.py b/learner/util/visualise.py
index 58ae8ffe..65329019 100644
--- a/learner/util/visualise.py
+++ b/learner/util/visualise.py
@@ -1,6 +1,8 @@
 import os
 import sys
 
+from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
+
 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
 
 import re
@@ -320,3 +322,20 @@ def display_solved_test_stats(train_type, L, H, aggr, p):
 def get_max_of_parameters(df):
     df = df.drop(columns=["L", "aggr"]).max()
     return df
+
+
+def get_confusion_matrix(y_true_train, y_pred_train, y_true_test, y_pred_test, cutoff=-1):
+    y_true_train = np.rint(y_true_train).astype(int)
+    y_pred_train = np.rint(y_pred_train).astype(int)
+    y_true_test = np.rint(y_true_test).astype(int)
+    y_pred_test = np.rint(y_pred_test).astype(int)
+    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 10))
+    if cutoff == -1:
+        cutoff = max(max(y_true_train), max(y_true_test))+1
+    cm_train = confusion_matrix(y_true_train, y_pred_train, normalize="true", labels=list(range(0, cutoff)))
+    cm_test = confusion_matrix(y_true_test, y_pred_test, normalize="true", labels=list(range(0, cutoff)))
+    display_labels = [y if y%10==0 else "" for y in range(cutoff)]
+    for i, cm in enumerate([cm_train, cm_test]):
+        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_labels)
+        disp.plot(include_values=False, xticks_rotation="vertical", ax=ax[i], colorbar=False)
+        disp.im_.set_clim(0, 1)
+    return plt
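Usage sketch for the new `get_confusion_matrix` helper: it takes (roughly integer) true and predicted heuristic values for the train and test folds and returns `plt` ready for saving. Toy data only; in `train_kernel.py` the inputs come from the CV fold estimators:
```
# Toy driver for util.visualise.get_confusion_matrix; run from inside learner/.
# Data is synthetic, so the resulting figure is illustrative only.
import os
import numpy as np
from util.visualise import get_confusion_matrix

rng = np.random.default_rng(0)
y_true_train = rng.integers(0, 15, size=200)
y_pred_train = y_true_train + rng.integers(-2, 3, size=200)
y_true_test = rng.integers(0, 15, size=50)
y_pred_test = y_true_test + rng.integers(-3, 4, size=50)

plt = get_confusion_matrix(y_true_train, y_pred_train, y_true_test, y_pred_test)
os.makedirs("plots", exist_ok=True)
plt.savefig("plots/example_cm.pdf", bbox_inches="tight")
```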