From f8406c56ecf9a955d102ca7bbf1f4a4dffeb90d4 Mon Sep 17 00:00:00 2001 From: Spyros Date: Sun, 11 Aug 2024 21:20:59 +0100 Subject: [PATCH] make the memory graph self-initialize, move functionality of get_standard_by_db_id to get_nodes --- application/database/db.py | 64 ++++++++----------- application/database/inmemory_graph.py | 13 ++-- application/prompt_client/prompt_client.py | 4 +- .../parsers/cloud_native_security_controls.py | 2 +- .../parsers/juiceshop.py | 2 +- .../parsers/pci_dss.py | 2 +- 6 files changed, 41 insertions(+), 46 deletions(-) diff --git a/application/database/db.py b/application/database/db.py index f579d9e6..de2b24f1 100644 --- a/application/database/db.py +++ b/application/database/db.py @@ -742,13 +742,11 @@ def __get_all_nodes_and_cres(self) -> List[cre_defs.Document]: cres = [] node_ids = self.session.query(Node.id).all() for nid in node_ids: - nodes.append(self.get_node_by_db_id(nid[0])) - [result.append(node) for node in nodes] + result.extend(self.get_nodes(db_id=nid[0])) cre_ids = self.session.query(CRE.id).all() for cid in cre_ids: - cres.append(self.get_cre_by_db_id(cid[0])) - [result.append(cre) for cre in cres] + result.append(self.get_cre_by_db_id(cid[0])) return result def __introduces_cycle(self, node_from: str, node_to: str) -> Any: @@ -983,19 +981,24 @@ def get_nodes( description: Optional[str] = None, ntype: str = cre_defs.Standard.__name__, sectionID: Optional[str] = None, + db_id: Optional[str] = None, ) -> Optional[List[cre_defs.Node]]: nodes = [] - nodes_query = self.__get_nodes_query__( - name=name, - section=section, - subsection=subsection, - link=link, - version=version, - partial=partial, - ntype=ntype, - description=description, - sectionID=sectionID, - ) + nodes_query = None + if db_id: + nodes_query = self.session.query(Node).filter(Node.id == db_id) + else: + nodes_query = self.__get_nodes_query__( + name=name, + section=section, + subsection=subsection, + link=link, + version=version, + partial=partial, + ntype=ntype, + description=description, + sectionID=sectionID, + ) dbnodes = nodes_query.all() if dbnodes: for dbnode in dbnodes: @@ -1032,22 +1035,6 @@ def get_nodes( return [] - def get_node_by_db_id(self, id: str) -> cre_defs.Node: - node = self.session.query(Node).filter(Node.id == id).first() - if not node: - logger.error(f"Node {id} does not exist in the db") - return None - - cs = linked_cres = Links.query.filter(Links.node == id).all() - nodes = self.get_nodes( - name=node.name, - section=node.section, - subsection=node.subsection, - ntype=node.ntype, - sectionID=node.section_id, - )[0] - return nodes - def get_cre_by_db_id(self, id: str) -> cre_defs.CRE: """internal method, returns a shallow cre (no links) by its database id @@ -1194,12 +1181,13 @@ def get_CREs( for ls in linked_nodes: nd = self.session.query(Node).filter(Node.id == ls.node).first() if not include_only or (include_only and nd.name in include_only): - cre.add_link( - cre_defs.Link( - document=nodeFromDB(nd), - ltype=cre_defs.LinkTypes.from_str(ls.type), + n = nodeFromDB(nd) + if not cre.link_exists(n): + cre.add_link( + cre_defs.Link( + document=n, ltype=cre_defs.LinkTypes.from_str(ls.type) + ) ) - ) # TODO figure the query to merge the following two internal_links = ( self.session.query(InternalLinks) @@ -1231,7 +1219,9 @@ def get_CREs( elif il.group == dbcre.id: res = q.filter(CRE.id == il.cre).first() ltype = cre_defs.LinkTypes.from_str(il.type) - cre.add_link(cre_defs.Link(document=CREfromDB(res), ltype=ltype)) + c = CREfromDB(res) + if not cre.link_exists(c): + cre.add_link(cre_defs.Link(document=c, ltype=ltype)) cres.append(cre) return cres diff --git a/application/database/inmemory_graph.py b/application/database/inmemory_graph.py index ecafb97d..c8c4f1e8 100644 --- a/application/database/inmemory_graph.py +++ b/application/database/inmemory_graph.py @@ -14,6 +14,7 @@ class CRE_Graph: def instance(cls, documents: List[defs.Document] = None) -> "CRE_Graph": if cls.__instance is None: cls.__instance = cls.__new__(cls) + cls.graph = nx.DiGraph() cls.graph = cls.__load_cre_graph(documents=documents) return cls.__instance @@ -87,7 +88,7 @@ def get_path(self, start: str, end: str) -> List[Tuple[str, str]]: @classmethod def add_cre(cls, dbcre: defs.CRE, graph: nx.DiGraph) -> nx.DiGraph: if dbcre: - graph.add_node(f"CRE: {dbcre.id}", internal_id=dbcre.id) + cls.graph.add_node(f"CRE: {dbcre.id}", internal_id=dbcre.id) else: logger.error("Called with dbcre being none") return graph @@ -95,7 +96,7 @@ def add_cre(cls, dbcre: defs.CRE, graph: nx.DiGraph) -> nx.DiGraph: @classmethod def add_dbnode(cls, dbnode: defs.Node, graph: nx.DiGraph) -> nx.DiGraph: if dbnode: - graph.add_node( + cls.graph.add_node( "Node: " + str(dbnode.id), internal_id=dbnode.id, ) @@ -105,7 +106,10 @@ def add_dbnode(cls, dbnode: defs.Node, graph: nx.DiGraph) -> nx.DiGraph: @classmethod def __load_cre_graph(cls, documents: List[defs.Document]) -> nx.Graph: - graph = nx.DiGraph() + graph = cls.graph + if not graph: + graph = nx.DiGraph() + for doc in documents: from_doctype = None if doc.doctype == defs.Credoctypes.CRE: @@ -122,9 +126,10 @@ def __load_cre_graph(cls, documents: List[defs.Document]) -> nx.Graph: else: graph = cls.add_dbnode(dbnode=link.document, graph=graph) to_doctype = "Node" - graph = graph.add_edge( + graph.add_edge( f"{from_doctype}: {doc.id}", f"{to_doctype}: {link.document.id}", ltype=link.ltype, ) + cls.graph = graph return graph diff --git a/application/prompt_client/prompt_client.py b/application/prompt_client/prompt_client.py index 83b81193..c3838298 100644 --- a/application/prompt_client/prompt_client.py +++ b/application/prompt_client/prompt_client.py @@ -148,7 +148,7 @@ def generate_embeddings( logger.info(f"generating {len(missing_embeddings)} embeddings") for id in missing_embeddings: cre = database.get_cre_by_db_id(id) - node = database.get_node_by_db_id(id) + node = database.get_nodes(db_id=id) content = "" if node: if is_valid_url(node.hyperlink): @@ -464,7 +464,7 @@ def generate_text(self, prompt: str) -> Dict[str, str]: ) closest_object = None if closest_id: - closest_object = self.database.get_node_by_db_id(closest_id) + closest_object = self.database.get_nodes(db_id=closest_id) logger.info( f"The prompt {prompt}, was most similar to object \n{closest_object}\n, with similarity:{similarity}" diff --git a/application/utils/external_project_parsers/parsers/cloud_native_security_controls.py b/application/utils/external_project_parsers/parsers/cloud_native_security_controls.py index 1d492fa5..be216c27 100644 --- a/application/utils/external_project_parsers/parsers/cloud_native_security_controls.py +++ b/application/utils/external_project_parsers/parsers/cloud_native_security_controls.py @@ -64,7 +64,7 @@ def parse(self, cache: db.Node_collection, ph: prompt_client.PromptHandler): ) standard_id = ph.get_id_of_most_similar_node(cnsc_embeddings) if standard_id: - dbstandard = cache.get_node_by_db_id(standard_id) + dbstandard = cache.get_nodes(db_id=standard_id) logger.info( f"found an appropriate standard for Cloud Native Security Control {cnsc.section}:{cnsc.subsection}, it is: {dbstandard.name}:{dbstandard.section}" ) diff --git a/application/utils/external_project_parsers/parsers/juiceshop.py b/application/utils/external_project_parsers/parsers/juiceshop.py index 43225569..2f8cd21c 100644 --- a/application/utils/external_project_parsers/parsers/juiceshop.py +++ b/application/utils/external_project_parsers/parsers/juiceshop.py @@ -78,7 +78,7 @@ def parse( f"could not find an appropriate CRE for Juiceshop challenge {chal.section}, findings similarities with standards instead" ) standard_id = ph.get_id_of_most_similar_node(challenge_embeddings) - dbstandard = cache.get_node_by_db_id(standard_id) + dbstandard = cache.get_nodes(db_id=standard_id) logger.info( f"found an appropriate standard for Juiceshop challenge {chal.section}, it is: {dbstandard.section}" ) diff --git a/application/utils/external_project_parsers/parsers/pci_dss.py b/application/utils/external_project_parsers/parsers/pci_dss.py index bafc2091..8a451d41 100644 --- a/application/utils/external_project_parsers/parsers/pci_dss.py +++ b/application/utils/external_project_parsers/parsers/pci_dss.py @@ -79,7 +79,7 @@ def __parse( f"could not find an appropriate CRE for pci {pci_control.section}, findings similarities with standards instead" ) standard_id = prompt.get_id_of_most_similar_node(control_embeddings) - dbstandard = cache.get_node_by_db_id(standard_id) + dbstandard = cache.get_nodes(db_id=standard_id) logger.info( f"found an appropriate standard for pci {pci_control.section}, it is: {dbstandard.section}" )