From f8bbc431a86c7ae9c9f0d3047351451c99bb1a86 Mon Sep 17 00:00:00 2001 From: Chi Chen Date: Mon, 20 May 2019 13:30:29 -0700 Subject: [PATCH] fix multi target models --- megnet/data/graph.py | 7 +++---- megnet/models.py | 6 ++---- megnet/tests/test_models.py | 11 +++++++++++ megnet/utils/preprocessing.py | 16 ++++++++++++++++ 4 files changed, 32 insertions(+), 8 deletions(-) diff --git a/megnet/data/graph.py b/megnet/data/graph.py index 687af75cc..ba0804b27 100644 --- a/megnet/data/graph.py +++ b/megnet/data/graph.py @@ -1,6 +1,6 @@ from operator import itemgetter import numpy as np -from megnet.utils.general_utils import expand_1st +from megnet.utils.general_utils import expand_1st, to_list from monty.json import MSONable from megnet.data import local_env from inspect import signature @@ -115,7 +115,7 @@ def get_flat_data(self, graphs, targets): Args: graphs: (list of dictionary) list of graph dictionary for each structure - targets: (list of float) correpsonding target values for each structure + targets: (list of float or list) correpsonding target values for each structure Returns: tuple(node_features, edges_features, global_values, index1, index2, targets) @@ -125,7 +125,6 @@ def get_flat_data(self, graphs, targets): states = [] index1 = [] index2 = [] - final_targets = [] for g, t in zip(graphs, targets): if isinstance(g, dict): @@ -134,7 +133,7 @@ def get_flat_data(self, graphs, targets): states.append(g['state']) index1.append(g['index1']) index2.append(g['index2']) - final_targets.append([t]) + final_targets.append(to_list(t)) return atoms, bonds, states, index1, index2, final_targets def _get_dummy_convertor(self): diff --git a/megnet/models.py b/megnet/models.py index 9b265e7f9..d6d46b32f 100644 --- a/megnet/models.py +++ b/megnet/models.py @@ -8,7 +8,7 @@ from megnet.callbacks import ModelCheckpointMAE, ManualStop, ReduceLRUponNan from megnet.data.graph import GraphBatchDistanceConvert, GraphBatchGenerator, GaussianDistance from megnet.data.crystal import CrystalGraph -from megnet.utils.preprocessing import StandardScaler +from megnet.utils.preprocessing import DummyScaler import numpy as np import os from warnings import warn @@ -40,7 +40,7 @@ class GraphModel: def __init__(self, model, graph_convertor, - target_scaler=StandardScaler(mean=0, std=1, is_intensive=True), + target_scaler=DummyScaler(), metadata=None, **kwargs): self.model = model @@ -125,13 +125,11 @@ def train_from_graphs(self, callbacks.append(ReduceLRUponNan()) train_nb_atoms = [len(i['atom']) for i in train_graphs] train_targets = [self.target_scaler.transform(i, j) for i, j in zip(train_targets, train_nb_atoms)] - train_targets = np.array(train_targets).ravel() if validation_graphs is not None: filepath = pjoin(dirname, 'val_mae_{epoch:05d}_{%s:.6f}.hdf5' % monitor) val_nb_atoms = [len(i['atom']) for i in validation_graphs] validation_targets = [self.target_scaler.transform(i, j) for i, j in zip(validation_targets, val_nb_atoms)] - validation_targets = np.array(validation_targets).ravel() val_inputs = self.graph_convertor.get_flat_data(validation_graphs, validation_targets) val_generator = self._create_generator(*val_inputs, diff --git a/megnet/tests/test_models.py b/megnet/tests/test_models.py index d80f68cfd..53bd77ac0 100644 --- a/megnet/tests/test_models.py +++ b/megnet/tests/test_models.py @@ -56,6 +56,10 @@ def __getitem__(self, index): n1=4, n2=4, n3=4, npass=1, ntarget=1, graph_convertor=CrystalGraph(bond_convertor=GaussianDistance(np.linspace(0, 5, 10), 0.5)), ) + cls.model2 = MEGNetModel(10, 2, nblocks=1, lr=1e-2, + n1=4, n2=4, n3=4, npass=1, ntarget=2, + graph_convertor=CrystalGraph(bond_convertor=GaussianDistance(np.linspace(0, 5, 10), 0.5)), + ) def test_train_pred(self): s = Structure.from_file(os.path.join(cwd, '../data/tests/cifs/BaTiO3_mp-2998_computed.cif')) @@ -100,6 +104,13 @@ def test_single_atom_structure(self): pred = self.model.predict_structure(s) self.assertEqual(len(pred.ravel()), 1) + def test_two_targets(self): + s = Structure(Lattice.cubic(3), ['Si'], [[0, 0, 0]]) + # initialize the model + self.model2.train([s, s], [[0.1, 0.2], [0.1, 0.2]], epochs=1) + pred = self.model2.predict_structure(s) + self.assertEqual(len(pred.ravel()), 2) + def test_save_and_load(self): weights1 = self.model.get_weights() with ScratchDir('.'): diff --git a/megnet/utils/preprocessing.py b/megnet/utils/preprocessing.py index 7d9517948..00c9379ab 100644 --- a/megnet/utils/preprocessing.py +++ b/megnet/utils/preprocessing.py @@ -70,3 +70,19 @@ def __str__(self): def __repr__(self): return str(self) + + +class DummyScaler(MSONable): + """ + Dummy scaler does nothing + """ + def transform(self, target, n=1): + return target + + def inverse_transform(self, transformed_target, n=1): + return transformed_target + + @classmethod + def from_training_data(cls, structures, targets, is_intensive=True): + return cls() +