From 1b4df74177680e5bcc85cf802f5ebd2e82f50501 Mon Sep 17 00:00:00 2001 From: Chi Chen Date: Mon, 8 Apr 2019 13:56:48 -0700 Subject: [PATCH 1/3] add default convertors --- megnet/models.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/megnet/models.py b/megnet/models.py index 9b954f60c..bf1bb2a71 100644 --- a/megnet/models.py +++ b/megnet/models.py @@ -8,8 +8,8 @@ from keras.models import Model from megnet.callbacks import ModelCheckpointMAE, ManualStop from megnet.utils.general_utils import expand_1st -from megnet.data.graph import GraphBatchDistanceConvert, GraphBatchGenerator -from megnet.data.crystal import graphs2inputs +from megnet.data.graph import GraphBatchDistanceConvert, GraphBatchGenerator, GaussianDistance +from megnet.data.crystal import graphs2inputs, CrystalGraph import numpy as np from sklearn.preprocessing import StandardScaler import os @@ -345,6 +345,10 @@ def one_block(a, b, c, has_ff=True): model = Model(inputs=[x1, x2, x3, x4, x5, x6, x7], outputs=out) model.compile(Adam(lr), loss) + if graph_convertor is None: + graph_convertor = CrystalGraph() + if distance_convertor is None: + distance_convertor = GaussianDistance(np.linspace(0, 4, 100), 0.5) super(MEGNetModel, self).__init__( model=model, graph_convertor=graph_convertor, distance_convertor=distance_convertor) From 83a197a4330509b3a46644b26e5332a79be7717e Mon Sep 17 00:00:00 2001 From: Chi Chen Date: Mon, 8 Apr 2019 13:58:31 -0700 Subject: [PATCH 2/3] change to paper default params --- megnet/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megnet/models.py b/megnet/models.py index bf1bb2a71..b18a870b4 100644 --- a/megnet/models.py +++ b/megnet/models.py @@ -348,7 +348,7 @@ def one_block(a, b, c, has_ff=True): if graph_convertor is None: graph_convertor = CrystalGraph() if distance_convertor is None: - distance_convertor = GaussianDistance(np.linspace(0, 4, 100), 0.5) + distance_convertor = GaussianDistance(np.linspace(0, 5, 100), 0.5) super(MEGNetModel, self).__init__( model=model, graph_convertor=graph_convertor, distance_convertor=distance_convertor) From 8128de23ba80f55c978dd902154409ae0a3daab3 Mon Sep 17 00:00:00 2001 From: Chi Chen Date: Mon, 8 Apr 2019 14:09:08 -0700 Subject: [PATCH 3/3] remove state-dependent standardscaler --- megnet/models.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/megnet/models.py b/megnet/models.py index b18a870b4..72dc90d6a 100644 --- a/megnet/models.py +++ b/megnet/models.py @@ -11,7 +11,6 @@ from megnet.data.graph import GraphBatchDistanceConvert, GraphBatchGenerator, GaussianDistance from megnet.data.crystal import graphs2inputs, CrystalGraph import numpy as np -from sklearn.preprocessing import StandardScaler import os from monty.serialization import dumpfn, loadfn @@ -41,7 +40,6 @@ def __init__(self, self.model = model self.graph_convertor = graph_convertor self.distance_convertor = distance_convertor - self.yscaler = StandardScaler() def __getattr__(self, p): return getattr(self.model, p) @@ -111,10 +109,10 @@ def train_from_graphs(self, os.makedirs(dirname) if callbacks is None: callbacks = [ManualStop()] - train_targets = self.yscaler.fit_transform(np.array(train_targets).reshape((-1, 1))).ravel() + train_targets = np.array(train_targets).ravel() if validation_graphs is not None: filepath = pjoin(dirname, 'val_mae_{epoch:05d}_{val_mae:.6f}.hdf5') - validation_targets = self.yscaler.transform(np.array(validation_targets).reshape((-1, 1))).ravel() + validation_targets = np.array(validation_targets).ravel() val_inputs = graphs2inputs(validation_graphs, validation_targets) val_generator = self._create_generator(*val_inputs, @@ -126,7 +124,7 @@ def train_from_graphs(self, save_weights_only=False, val_gen=val_generator, steps_per_val=steps_per_val, - y_scaler=self.yscaler)]) + y_scaler=None)]) else: val_generator = None steps_per_val = None @@ -154,7 +152,7 @@ def predict_structure(self, structure): expand_1st(np.array(gnode)), expand_1st(np.array(gbond)), ] - return self.yscaler.inverse_transform(self.predict(inp).reshape((-1, 1))).ravel() + return self.predict(inp).ravel() def _create_generator(self, *args, **kwargs): if self.distance_convertor is not None: