Skip to content

Commit

Permalink
refactor CrystalGraphConverter._create_graph_fast
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Mar 11, 2024
1 parent e7eddc7 commit 5b1656e
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions chgnet/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,17 +250,12 @@ def _create_graph_fast(
distance = np.ascontiguousarray(distance)
gc_saved = gc.get_threshold()
gc.set_threshold(0)
(
nodes,
directed_edges_list,
undirected_edges_list,
undirected_edges,
) = make_graph(
nodes, directed_edges, undirected_edges, undirected_edges = make_graph(
center_index, len(center_index), neighbor_index, image, distance, n_atoms
)
graph = Graph(nodes=nodes)
graph.directed_edges_list = directed_edges_list
graph.undirected_edges_list = undirected_edges_list
graph.directed_edges_list = directed_edges
graph.undirected_edges_list = undirected_edges
graph.undirected_edges = undirected_edges
gc.set_threshold(gc_saved[0])

Expand All @@ -282,7 +277,7 @@ def set_isolated_atom_response(
self.on_isolated_atoms = on_isolated_atoms
return

def as_dict(self) -> dict[str, float]:
def as_dict(self) -> dict[str, str | float]:
"""Save the args of the graph converter."""
return {
"atom_graph_cutoff": self.atom_graph_cutoff,
Expand All @@ -291,6 +286,6 @@ def as_dict(self) -> dict[str, float]:
}

@classmethod
def from_dict(cls, dict) -> CrystalGraphConverter:
def from_dict(cls, dct: dict) -> CrystalGraphConverter:
"""Create converter from dictionary."""
return CrystalGraphConverter(**dict)
return CrystalGraphConverter(**dct)

0 comments on commit 5b1656e

Please sign in to comment.