Skip to content
This repository has been archived by the owner on Apr 27, 2023. It is now read-only.

Commit

Permalink
fix multi target models
Browse files Browse the repository at this point in the history
  • Loading branch information
Chi Chen committed May 20, 2019
1 parent 2658739 commit f8bbc43
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 8 deletions.
7 changes: 3 additions & 4 deletions megnet/data/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from operator import itemgetter
import numpy as np
from megnet.utils.general_utils import expand_1st
from megnet.utils.general_utils import expand_1st, to_list
from monty.json import MSONable
from megnet.data import local_env
from inspect import signature
Expand Down Expand Up @@ -115,7 +115,7 @@ def get_flat_data(self, graphs, targets):
Args:
graphs: (list of dictionary) list of graph dictionary for each structure
targets: (list of float) correpsonding target values for each structure
targets: (list of float or list) correpsonding target values for each structure
Returns:
tuple(node_features, edges_features, global_values, index1, index2, targets)
Expand All @@ -125,7 +125,6 @@ def get_flat_data(self, graphs, targets):
states = []
index1 = []
index2 = []

final_targets = []
for g, t in zip(graphs, targets):
if isinstance(g, dict):
Expand All @@ -134,7 +133,7 @@ def get_flat_data(self, graphs, targets):
states.append(g['state'])
index1.append(g['index1'])
index2.append(g['index2'])
final_targets.append([t])
final_targets.append(to_list(t))
return atoms, bonds, states, index1, index2, final_targets

def _get_dummy_convertor(self):
Expand Down
6 changes: 2 additions & 4 deletions megnet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from megnet.callbacks import ModelCheckpointMAE, ManualStop, ReduceLRUponNan
from megnet.data.graph import GraphBatchDistanceConvert, GraphBatchGenerator, GaussianDistance
from megnet.data.crystal import CrystalGraph
from megnet.utils.preprocessing import StandardScaler
from megnet.utils.preprocessing import DummyScaler
import numpy as np
import os
from warnings import warn
Expand Down Expand Up @@ -40,7 +40,7 @@ class GraphModel:
def __init__(self,
model,
graph_convertor,
target_scaler=StandardScaler(mean=0, std=1, is_intensive=True),
target_scaler=DummyScaler(),
metadata=None,
**kwargs):
self.model = model
Expand Down Expand Up @@ -125,13 +125,11 @@ def train_from_graphs(self,
callbacks.append(ReduceLRUponNan())
train_nb_atoms = [len(i['atom']) for i in train_graphs]
train_targets = [self.target_scaler.transform(i, j) for i, j in zip(train_targets, train_nb_atoms)]
train_targets = np.array(train_targets).ravel()

if validation_graphs is not None:
filepath = pjoin(dirname, 'val_mae_{epoch:05d}_{%s:.6f}.hdf5' % monitor)
val_nb_atoms = [len(i['atom']) for i in validation_graphs]
validation_targets = [self.target_scaler.transform(i, j) for i, j in zip(validation_targets, val_nb_atoms)]
validation_targets = np.array(validation_targets).ravel()
val_inputs = self.graph_convertor.get_flat_data(validation_graphs, validation_targets)

val_generator = self._create_generator(*val_inputs,
Expand Down
11 changes: 11 additions & 0 deletions megnet/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def __getitem__(self, index):
n1=4, n2=4, n3=4, npass=1, ntarget=1,
graph_convertor=CrystalGraph(bond_convertor=GaussianDistance(np.linspace(0, 5, 10), 0.5)),
)
cls.model2 = MEGNetModel(10, 2, nblocks=1, lr=1e-2,
n1=4, n2=4, n3=4, npass=1, ntarget=2,
graph_convertor=CrystalGraph(bond_convertor=GaussianDistance(np.linspace(0, 5, 10), 0.5)),
)

def test_train_pred(self):
s = Structure.from_file(os.path.join(cwd, '../data/tests/cifs/BaTiO3_mp-2998_computed.cif'))
Expand Down Expand Up @@ -100,6 +104,13 @@ def test_single_atom_structure(self):
pred = self.model.predict_structure(s)
self.assertEqual(len(pred.ravel()), 1)

def test_two_targets(self):
s = Structure(Lattice.cubic(3), ['Si'], [[0, 0, 0]])
# initialize the model
self.model2.train([s, s], [[0.1, 0.2], [0.1, 0.2]], epochs=1)
pred = self.model2.predict_structure(s)
self.assertEqual(len(pred.ravel()), 2)

def test_save_and_load(self):
weights1 = self.model.get_weights()
with ScratchDir('.'):
Expand Down
16 changes: 16 additions & 0 deletions megnet/utils/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,19 @@ def __str__(self):

def __repr__(self):
return str(self)


class DummyScaler(MSONable):
"""
Dummy scaler does nothing
"""
def transform(self, target, n=1):
return target

def inverse_transform(self, transformed_target, n=1):
return transformed_target

@classmethod
def from_training_data(cls, structures, targets, is_intensive=True):
return cls()

0 comments on commit f8bbc43

Please sign in to comment.