From 5b1656e31576af48eec4ac025b24ae64dc9eda52 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Mon, 11 Mar 2024 16:02:18 +0100 Subject: [PATCH] refactor CrystalGraphConverter._create_graph_fast --- chgnet/graph/converter.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/chgnet/graph/converter.py b/chgnet/graph/converter.py index 9a29a85..bd702ef 100644 --- a/chgnet/graph/converter.py +++ b/chgnet/graph/converter.py @@ -250,17 +250,12 @@ def _create_graph_fast( distance = np.ascontiguousarray(distance) gc_saved = gc.get_threshold() gc.set_threshold(0) - ( - nodes, - directed_edges_list, - undirected_edges_list, - undirected_edges, - ) = make_graph( + nodes, directed_edges, undirected_edges, undirected_edges = make_graph( center_index, len(center_index), neighbor_index, image, distance, n_atoms ) graph = Graph(nodes=nodes) - graph.directed_edges_list = directed_edges_list - graph.undirected_edges_list = undirected_edges_list + graph.directed_edges_list = directed_edges + graph.undirected_edges_list = undirected_edges graph.undirected_edges = undirected_edges gc.set_threshold(gc_saved[0]) @@ -282,7 +277,7 @@ def set_isolated_atom_response( self.on_isolated_atoms = on_isolated_atoms return - def as_dict(self) -> dict[str, float]: + def as_dict(self) -> dict[str, str | float]: """Save the args of the graph converter.""" return { "atom_graph_cutoff": self.atom_graph_cutoff, @@ -291,6 +286,6 @@ def as_dict(self) -> dict[str, float]: } @classmethod - def from_dict(cls, dict) -> CrystalGraphConverter: + def from_dict(cls, dct: dict) -> CrystalGraphConverter: """Create converter from dictionary.""" - return CrystalGraphConverter(**dict) + return CrystalGraphConverter(**dct)