Skip to content
This repository has been archived by the owner on Apr 27, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' of github.com:materialsvirtuallab/megnet
Browse files Browse the repository at this point in the history
  • Loading branch information
shyuep committed Apr 8, 2019
2 parents efd6105 + 8128de2 commit 313ef9e
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions megnet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
from keras.models import Model
from megnet.callbacks import ModelCheckpointMAE, ManualStop
from megnet.utils.general_utils import expand_1st
from megnet.data.graph import GraphBatchDistanceConvert, GraphBatchGenerator
from megnet.data.crystal import graphs2inputs
from megnet.data.graph import GraphBatchDistanceConvert, GraphBatchGenerator, GaussianDistance
from megnet.data.crystal import graphs2inputs, CrystalGraph
import numpy as np
from sklearn.preprocessing import StandardScaler
import os

from monty.serialization import dumpfn, loadfn
Expand Down Expand Up @@ -41,7 +40,6 @@ def __init__(self,
self.model = model
self.graph_convertor = graph_convertor
self.distance_convertor = distance_convertor
self.yscaler = StandardScaler()

def __getattr__(self, p):
return getattr(self.model, p)
Expand Down Expand Up @@ -111,10 +109,10 @@ def train_from_graphs(self,
os.makedirs(dirname)
if callbacks is None:
callbacks = [ManualStop()]
train_targets = self.yscaler.fit_transform(np.array(train_targets).reshape((-1, 1))).ravel()
train_targets = np.array(train_targets).ravel()
if validation_graphs is not None:
filepath = pjoin(dirname, 'val_mae_{epoch:05d}_{val_mae:.6f}.hdf5')
validation_targets = self.yscaler.transform(np.array(validation_targets).reshape((-1, 1))).ravel()
validation_targets = np.array(validation_targets).ravel()
val_inputs = graphs2inputs(validation_graphs, validation_targets)

val_generator = self._create_generator(*val_inputs,
Expand All @@ -126,7 +124,7 @@ def train_from_graphs(self,
save_weights_only=False,
val_gen=val_generator,
steps_per_val=steps_per_val,
y_scaler=self.yscaler)])
y_scaler=None)])
else:
val_generator = None
steps_per_val = None
Expand Down Expand Up @@ -154,7 +152,7 @@ def predict_structure(self, structure):
expand_1st(np.array(gnode)),
expand_1st(np.array(gbond)),
]
return self.yscaler.inverse_transform(self.predict(inp).reshape((-1, 1))).ravel()
return self.predict(inp).ravel()

def _create_generator(self, *args, **kwargs):
if self.distance_convertor is not None:
Expand Down Expand Up @@ -345,6 +343,10 @@ def one_block(a, b, c, has_ff=True):
model = Model(inputs=[x1, x2, x3, x4, x5, x6, x7], outputs=out)
model.compile(Adam(lr), loss)

if graph_convertor is None:
graph_convertor = CrystalGraph()
if distance_convertor is None:
distance_convertor = GaussianDistance(np.linspace(0, 5, 100), 0.5)
super(MEGNetModel, self).__init__(
model=model, graph_convertor=graph_convertor,
distance_convertor=distance_convertor)

0 comments on commit 313ef9e

Please sign in to comment.