diff --git a/megnet/models.py b/megnet/models.py index 712efab85..4db7215d3 100644 --- a/megnet/models.py +++ b/megnet/models.py @@ -150,10 +150,24 @@ def predict_structure(self, structure): structure: pymatgen structure or molecule Returns: - batch generator object + predicted target value """ - inp = self.graph_convertor.get_input(structure) - return self.target_scaler.inverse_transform(self.predict(inp).ravel(), len(structure)) + graph = self.graph_convertor.convert(structure) + return self.predict_graph(graph) + + def predict_graph(self, graph): + """ + Predict property from graph + + Args: + graph: a graph dictionary, see megnet.data.graph + + Returns: + predicted target value + + """ + inp = self.graph_convertor.graph_to_input(graph) + return self.target_scaler.inverse_transform(self.predict(inp).ravel(), len(graph['atom'])) def _create_generator(self, *args, **kwargs): if hasattr(self.graph_convertor, 'bond_convertor'):