diff --git a/megnet/data/crystal.py b/megnet/data/crystal.py index 8b9c89f29..f84450d4e 100644 --- a/megnet/data/crystal.py +++ b/megnet/data/crystal.py @@ -17,89 +17,85 @@ def __init__(self, r=4): self.r = r def convert(self, structure, state_attributes=None): - return structure2graph(structure, state_attributes, r=self.r) + """ + Take a pymatgen structure and convert it to a index-type graph representation + The graph will have node, distance, index1, index2, where node is a vector of Z number + of atoms in the structure, index1 and index2 mark the atom indices forming the bond and separated by + distance + :param structure: (pymatgen structure) + :param state_attributes: (list) state attributes + :return: (dictionary) + """ + atom_i_segment_id = [] # index list for the center atom i for all bonds (row index) + atom_i_j_id = [] # index list for atom j + atom_number = [] + all_neighbors = structure.get_all_neighbors(self.r, include_index=True) + distances = [] + state_attributes = state_attributes or [[0, 0]] + for k, n in enumerate(all_neighbors): + atom_number.append(structure.sites[k].specie.Z) + if len(n) < 1: + index = None + else: + _, distance, index = list(zip(*n)) + index = np.array(index) + distance = np.array(distance) + + if index is not None: + ind = np.argsort(index) + it = itemgetter(*ind) + index = it(index) + index = to_list(index) + index = [int(i) for i in index] + distance = distance[ind] + distances.append(distance) + atom_i_segment_id.extend([k] * len(index)) + atom_i_j_id.extend(index) + else: + pass + if len(distances) < 1: + return None + else: + return {'distance': np.concatenate(distances), + 'index1': atom_i_segment_id, + 'index2': atom_i_j_id, + 'node': atom_number, + 'state': state_attributes} + + def get_input(self, structure, distance_convertor=None, **kwargs): + """ + Take a pymatgen structure and convert it to a index-type graph + representation as model input + + :param structure: (pymatgen structure) + :param r: (float) cutoff radius + :param state_attributes: (list) a list of state attributes + :param distance_convertor: (object) convert numeric distance values + into a vector as bond features + :return: (dictionary) inputs for model.predict + """ + graph = self.convert(structure) + if distance_convertor is None: + centers = kwargs.get('centers', np.linspace(0, 6, 100)) + width = kwargs.get('width', 0.5) + distance_convertor = GaussianDistance(centers, width) + + gnode = [0] * len(structure) + gbond = [0] * len(graph['index1']) + + return [expand_1st(graph['node']), + expand_1st(distance_convertor.convert(graph['distance'])), + expand_1st(np.array(graph['state'])), + expand_1st(np.array(graph['index1'])), + expand_1st(np.array(graph['index2'])), + expand_1st(np.array(gnode)), + expand_1st(np.array(gbond)), + ] def __call__(self, structure, state_attributes=None): return self.convert(structure, state_attributes) -def structure2graph(structure, state_attributes=None, r=4): - """ - Take a pymatgen structure and convert it to a index-type graph representation - The graph will have node, distance, index1, index2, where node is a vector of Z number - of atoms in the structure, index1 and index2 mark the atom indices forming the bond and separated by - distance - :param structure: (pymatgen structure) - :param state_attributes: (list) state attributes - :param r: (float) distance cutoff - :return: (dictionary) - """ - atom_i_segment_id = [] # index list for the center atom i for all bonds (row index) - atom_i_j_id = [] # index list for atom j - atom_number = [] - all_neighbors = structure.get_all_neighbors(r, include_index=True) - distances = [] - state_attributes = state_attributes or [[0, 0]] - for k, n in enumerate(all_neighbors): - atom_number.append(structure.sites[k].specie.Z) - if len(n) < 1: - index = None - else: - _, distance, index = list(zip(*n)) - index = np.array(index) - distance = np.array(distance) - - if index is not None: - ind = np.argsort(index) - it = itemgetter(*ind) - index = it(index) - index = to_list(index) - index = [int(i) for i in index] - distance = distance[ind] - distances.append(distance) - atom_i_segment_id.extend([k] * len(index)) - atom_i_j_id.extend(index) - else: - pass - if len(distances) < 1: - return None - else: - return {'distance': np.concatenate(distances), - 'index1': atom_i_segment_id, - 'index2': atom_i_j_id, - 'node': atom_number, - 'state': state_attributes} - - -def structure2input(structure, r=4, distance_convertor=None, **kwargs): - """ - Take a pymatgen structure and convert it to a index-type graph representation as model input - - :param structure: (pymatgen structure) - :param r: (float) cutoff radius - :param state_attributes: (list) a list of state attributes - :param distance_convertor: (object) convert numeric distance values into a vector as bond features - :return: (dictionary) inputs for model.predict - """ - graph = structure2graph(structure, r=r) - if distance_convertor is None: - centers = kwargs.get('centers', np.linspace(0, 6, 100)) - width = kwargs.get('width', 0.5) - distance_convertor = GaussianDistance(centers, width) - - gnode = [0] * len(structure) - gbond = [0] * len(graph['index1']) - - return [expand_1st(graph['node']), - expand_1st(distance_convertor.convert(graph['distance'])), - expand_1st(np.array(graph['state'])), - expand_1st(np.array(graph['index1'])), - expand_1st(np.array(graph['index2'])), - expand_1st(np.array(gnode)), - expand_1st(np.array(gbond)), - ] - - def graphs2inputs(graphs, targets): """ Expand the graph dictionary to form a list of features and targets diff --git a/megnet/data/graph.py b/megnet/data/graph.py index 4f4452fdc..5d83b459d 100644 --- a/megnet/data/graph.py +++ b/megnet/data/graph.py @@ -51,15 +51,18 @@ def convert(self, d): class GraphBatchGenerator: """ - A generator class that assembles several structures (indicated by batch_size) and form (x, y) pairs for model training + A generator class that assembles several structures (indicated by + batch_size) and form (x, y) pairs for model training :param atom_features: (list of np.array) list of atom feature matrix, :param bond_features: (list of np.array) list of bond features matrix - :param state_features: (list of np.array) list of [1, G] state features, where G is the global state feature dimension - :param index1_list: (list of integer) list of (M, ) one side atomic index of the bond, M is different for different - structures - :param index2_list: (list of integer) list of (M, ) the other side atomic index of the bond, M is different for different - structures, but it has to be the same as the correponding index1. + :param state_features: (list of np.array) list of [1, G] state features, + where G is the global state feature dimension + :param index1_list: (list of integer) list of (M, ) one side atomic index + of the bond, M is different for different structures + :param index2_list: (list of integer) list of (M, ) the other side atomic + index of the bond, M is different for different structures, but it has + to be the same as the correponding index1. :param targets: (numpy array), N*1, where N is the number of structures :param batch_size: (int) number of samples in a batch """ diff --git a/megnet/data/tests/test_crystal.py b/megnet/data/tests/test_crystal.py index d4c56efbf..a827f8abe 100644 --- a/megnet/data/tests/test_crystal.py +++ b/megnet/data/tests/test_crystal.py @@ -1,5 +1,5 @@ import unittest -from megnet.data.crystal import CrystalGraph, structure2graph, structure2input, graphs2inputs +from megnet.data.crystal import CrystalGraph, graphs2inputs from pymatgen import Structure import os @@ -26,12 +26,14 @@ def test_crystalgraph(self): graph3 = cg(self.structures[0]) self.assertListEqual(graph['node'], graph3['node']) - def test_s2graph(self): - graph = structure2graph(self.structures[0]) + def test_convert(self): + cg = CrystalGraph() + graph = cg.convert(self.structures[0]) self.assertListEqual(graph['node'], [i.specie.Z for i in self.structures[0]]) - def test_s2input(self): - inp = structure2input(self.structures[0]) + def test_get_input(self): + cg = CrystalGraph() + inp = cg.get_input(self.structures[0]) self.assertEqual(len(inp), 7) shapes = [i.shape for i in inp] true_shapes = [(1, 28), (1, 704, 100), (1, 1, 2), (1, 704), (1, 704), (1, 28), (1, 704)] @@ -39,8 +41,8 @@ def test_s2input(self): self.assertListEqual(list(i), list(j)) def test_g2inputs(self): - - graphs = [structure2graph(i) for i in self.structures] + cg = CrystalGraph() + graphs = [cg.convert(i) for i in self.structures] targets = [0.1, 0.2] inp = graphs2inputs(graphs, targets) self.assertListEqual([len(i) for i in inp], [2] * 6)