Skip to content
This repository has been archived by the owner on Apr 27, 2023. It is now read-only.

Commit

Permalink
remove state-dependent standardscaler
Browse files Browse the repository at this point in the history
  • Loading branch information
Chi Chen committed Apr 8, 2019
1 parent 83a197a commit 8128de2
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions megnet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from megnet.data.graph import GraphBatchDistanceConvert, GraphBatchGenerator, GaussianDistance
from megnet.data.crystal import graphs2inputs, CrystalGraph
import numpy as np
from sklearn.preprocessing import StandardScaler
import os

from monty.serialization import dumpfn, loadfn
Expand Down Expand Up @@ -41,7 +40,6 @@ def __init__(self,
self.model = model
self.graph_convertor = graph_convertor
self.distance_convertor = distance_convertor
self.yscaler = StandardScaler()

def __getattr__(self, p):
return getattr(self.model, p)
Expand Down Expand Up @@ -111,10 +109,10 @@ def train_from_graphs(self,
os.makedirs(dirname)
if callbacks is None:
callbacks = [ManualStop()]
train_targets = self.yscaler.fit_transform(np.array(train_targets).reshape((-1, 1))).ravel()
train_targets = np.array(train_targets).ravel()
if validation_graphs is not None:
filepath = pjoin(dirname, 'val_mae_{epoch:05d}_{val_mae:.6f}.hdf5')
validation_targets = self.yscaler.transform(np.array(validation_targets).reshape((-1, 1))).ravel()
validation_targets = np.array(validation_targets).ravel()
val_inputs = graphs2inputs(validation_graphs, validation_targets)

val_generator = self._create_generator(*val_inputs,
Expand All @@ -126,7 +124,7 @@ def train_from_graphs(self,
save_weights_only=False,
val_gen=val_generator,
steps_per_val=steps_per_val,
y_scaler=self.yscaler)])
y_scaler=None)])
else:
val_generator = None
steps_per_val = None
Expand Down Expand Up @@ -154,7 +152,7 @@ def predict_structure(self, structure):
expand_1st(np.array(gnode)),
expand_1st(np.array(gbond)),
]
return self.yscaler.inverse_transform(self.predict(inp).reshape((-1, 1))).ravel()
return self.predict(inp).ravel()

def _create_generator(self, *args, **kwargs):
if self.distance_convertor is not None:
Expand Down

0 comments on commit 8128de2

Please sign in to comment.