Skip to content

Commit

Permalink
refactor CrystalGraphConverter._create_graph_fast
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Mar 11, 2024
1 parent e7eddc7 commit f1b88ee
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions chgnet/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,25 +250,19 @@ def _create_graph_fast(
distance = np.ascontiguousarray(distance)
gc_saved = gc.get_threshold()
gc.set_threshold(0)
(
nodes,
directed_edges_list,
undirected_edges_list,
undirected_edges,
) = make_graph(
nodes, dir_edges_list, undir_edges_list, undirected_edges = make_graph(
center_index, len(center_index), neighbor_index, image, distance, n_atoms
)
graph = Graph(nodes=nodes)
graph.directed_edges_list = directed_edges_list
graph.undirected_edges_list = undirected_edges_list
graph.directed_edges_list = dir_edges_list
graph.undirected_edges_list = undir_edges_list
graph.undirected_edges = undirected_edges
gc.set_threshold(gc_saved[0])

return graph

def set_isolated_atom_response(
self,
on_isolated_atoms: Literal["ignore", "warn", "error"],
self, on_isolated_atoms: Literal["ignore", "warn", "error"]
) -> None:
"""Set the graph converter's response to isolated atom graph
Args:
Expand All @@ -282,7 +276,7 @@ def set_isolated_atom_response(
self.on_isolated_atoms = on_isolated_atoms
return

def as_dict(self) -> dict[str, float]:
def as_dict(self) -> dict[str, str | float]:
"""Save the args of the graph converter."""
return {
"atom_graph_cutoff": self.atom_graph_cutoff,
Expand All @@ -291,6 +285,6 @@ def as_dict(self) -> dict[str, float]:
}

@classmethod
def from_dict(cls, dict) -> CrystalGraphConverter:
def from_dict(cls, dct: dict) -> CrystalGraphConverter:
"""Create converter from dictionary."""
return CrystalGraphConverter(**dict)
return CrystalGraphConverter(**dct)

0 comments on commit f1b88ee

Please sign in to comment.