Skip to content
This repository has been archived by the owner on Apr 27, 2023. It is now read-only.

Commit

Permalink
add predict_graph method for MEGNetModel
Browse files Browse the repository at this point in the history
  • Loading branch information
Chi Chen committed May 15, 2019
1 parent 9567a6e commit 4fbd86e
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions megnet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,24 @@ def predict_structure(self, structure):
structure: pymatgen structure or molecule
Returns:
batch generator object
predicted target value
"""
inp = self.graph_convertor.get_input(structure)
return self.target_scaler.inverse_transform(self.predict(inp).ravel(), len(structure))
graph = self.graph_convertor.convert(structure)
return self.predict_graph(graph)

def predict_graph(self, graph):
"""
Predict property from graph
Args:
graph: a graph dictionary, see megnet.data.graph
Returns:
predicted target value
"""
inp = self.graph_convertor.graph_to_input(graph)
return self.target_scaler.inverse_transform(self.predict(inp).ravel(), len(graph['atom']))

def _create_generator(self, *args, **kwargs):
if hasattr(self.graph_convertor, 'bond_convertor'):
Expand Down

0 comments on commit 4fbd86e

Please sign in to comment.