Skip to content
This repository has been archived by the owner on Apr 27, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' of github.com:materialsvirtuallab/megnet
Browse files Browse the repository at this point in the history
  • Loading branch information
chc273 committed Apr 8, 2019
2 parents b55519a + 313ef9e commit 0f3acb4
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 91 deletions.
152 changes: 74 additions & 78 deletions megnet/data/crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,89 +17,85 @@ def __init__(self, r=4):
self.r = r

def convert(self, structure, state_attributes=None):
return structure2graph(structure, state_attributes, r=self.r)
"""
Take a pymatgen structure and convert it to a index-type graph representation
The graph will have node, distance, index1, index2, where node is a vector of Z number
of atoms in the structure, index1 and index2 mark the atom indices forming the bond and separated by
distance
:param structure: (pymatgen structure)
:param state_attributes: (list) state attributes
:return: (dictionary)
"""
atom_i_segment_id = [] # index list for the center atom i for all bonds (row index)
atom_i_j_id = [] # index list for atom j
atom_number = []
all_neighbors = structure.get_all_neighbors(self.r, include_index=True)
distances = []
state_attributes = state_attributes or [[0, 0]]
for k, n in enumerate(all_neighbors):
atom_number.append(structure.sites[k].specie.Z)
if len(n) < 1:
index = None
else:
_, distance, index = list(zip(*n))
index = np.array(index)
distance = np.array(distance)

if index is not None:
ind = np.argsort(index)
it = itemgetter(*ind)
index = it(index)
index = to_list(index)
index = [int(i) for i in index]
distance = distance[ind]
distances.append(distance)
atom_i_segment_id.extend([k] * len(index))
atom_i_j_id.extend(index)
else:
pass
if len(distances) < 1:
return None
else:
return {'distance': np.concatenate(distances),
'index1': atom_i_segment_id,
'index2': atom_i_j_id,
'node': atom_number,
'state': state_attributes}

def get_input(self, structure, distance_convertor=None, **kwargs):
"""
Take a pymatgen structure and convert it to a index-type graph
representation as model input
:param structure: (pymatgen structure)
:param r: (float) cutoff radius
:param state_attributes: (list) a list of state attributes
:param distance_convertor: (object) convert numeric distance values
into a vector as bond features
:return: (dictionary) inputs for model.predict
"""
graph = self.convert(structure)
if distance_convertor is None:
centers = kwargs.get('centers', np.linspace(0, 6, 100))
width = kwargs.get('width', 0.5)
distance_convertor = GaussianDistance(centers, width)

gnode = [0] * len(structure)
gbond = [0] * len(graph['index1'])

return [expand_1st(graph['node']),
expand_1st(distance_convertor.convert(graph['distance'])),
expand_1st(np.array(graph['state'])),
expand_1st(np.array(graph['index1'])),
expand_1st(np.array(graph['index2'])),
expand_1st(np.array(gnode)),
expand_1st(np.array(gbond)),
]

def __call__(self, structure, state_attributes=None):
return self.convert(structure, state_attributes)


def structure2graph(structure, state_attributes=None, r=4):
"""
Take a pymatgen structure and convert it to a index-type graph representation
The graph will have node, distance, index1, index2, where node is a vector of Z number
of atoms in the structure, index1 and index2 mark the atom indices forming the bond and separated by
distance
:param structure: (pymatgen structure)
:param state_attributes: (list) state attributes
:param r: (float) distance cutoff
:return: (dictionary)
"""
atom_i_segment_id = [] # index list for the center atom i for all bonds (row index)
atom_i_j_id = [] # index list for atom j
atom_number = []
all_neighbors = structure.get_all_neighbors(r, include_index=True)
distances = []
state_attributes = state_attributes or [[0, 0]]
for k, n in enumerate(all_neighbors):
atom_number.append(structure.sites[k].specie.Z)
if len(n) < 1:
index = None
else:
_, distance, index = list(zip(*n))
index = np.array(index)
distance = np.array(distance)

if index is not None:
ind = np.argsort(index)
it = itemgetter(*ind)
index = it(index)
index = to_list(index)
index = [int(i) for i in index]
distance = distance[ind]
distances.append(distance)
atom_i_segment_id.extend([k] * len(index))
atom_i_j_id.extend(index)
else:
pass
if len(distances) < 1:
return None
else:
return {'distance': np.concatenate(distances),
'index1': atom_i_segment_id,
'index2': atom_i_j_id,
'node': atom_number,
'state': state_attributes}


def structure2input(structure, r=4, distance_convertor=None, **kwargs):
"""
Take a pymatgen structure and convert it to a index-type graph representation as model input
:param structure: (pymatgen structure)
:param r: (float) cutoff radius
:param state_attributes: (list) a list of state attributes
:param distance_convertor: (object) convert numeric distance values into a vector as bond features
:return: (dictionary) inputs for model.predict
"""
graph = structure2graph(structure, r=r)
if distance_convertor is None:
centers = kwargs.get('centers', np.linspace(0, 6, 100))
width = kwargs.get('width', 0.5)
distance_convertor = GaussianDistance(centers, width)

gnode = [0] * len(structure)
gbond = [0] * len(graph['index1'])

return [expand_1st(graph['node']),
expand_1st(distance_convertor.convert(graph['distance'])),
expand_1st(np.array(graph['state'])),
expand_1st(np.array(graph['index1'])),
expand_1st(np.array(graph['index2'])),
expand_1st(np.array(gnode)),
expand_1st(np.array(gbond)),
]


def graphs2inputs(graphs, targets):
"""
Expand the graph dictionary to form a list of features and targets
Expand Down
15 changes: 9 additions & 6 deletions megnet/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,18 @@ def convert(self, d):

class GraphBatchGenerator:
"""
A generator class that assembles several structures (indicated by batch_size) and form (x, y) pairs for model training
A generator class that assembles several structures (indicated by
batch_size) and form (x, y) pairs for model training
:param atom_features: (list of np.array) list of atom feature matrix,
:param bond_features: (list of np.array) list of bond features matrix
:param state_features: (list of np.array) list of [1, G] state features, where G is the global state feature dimension
:param index1_list: (list of integer) list of (M, ) one side atomic index of the bond, M is different for different
structures
:param index2_list: (list of integer) list of (M, ) the other side atomic index of the bond, M is different for different
structures, but it has to be the same as the correponding index1.
:param state_features: (list of np.array) list of [1, G] state features,
where G is the global state feature dimension
:param index1_list: (list of integer) list of (M, ) one side atomic index
of the bond, M is different for different structures
:param index2_list: (list of integer) list of (M, ) the other side atomic
index of the bond, M is different for different structures, but it has
to be the same as the correponding index1.
:param targets: (numpy array), N*1, where N is the number of structures
:param batch_size: (int) number of samples in a batch
"""
Expand Down
16 changes: 9 additions & 7 deletions megnet/data/tests/test_crystal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from megnet.data.crystal import CrystalGraph, structure2graph, structure2input, graphs2inputs
from megnet.data.crystal import CrystalGraph, graphs2inputs
from pymatgen import Structure
import os

Expand All @@ -26,21 +26,23 @@ def test_crystalgraph(self):
graph3 = cg(self.structures[0])
self.assertListEqual(graph['node'], graph3['node'])

def test_s2graph(self):
graph = structure2graph(self.structures[0])
def test_convert(self):
cg = CrystalGraph()
graph = cg.convert(self.structures[0])
self.assertListEqual(graph['node'], [i.specie.Z for i in self.structures[0]])

def test_s2input(self):
inp = structure2input(self.structures[0])
def test_get_input(self):
cg = CrystalGraph()
inp = cg.get_input(self.structures[0])
self.assertEqual(len(inp), 7)
shapes = [i.shape for i in inp]
true_shapes = [(1, 28), (1, 704, 100), (1, 1, 2), (1, 704), (1, 704), (1, 28), (1, 704)]
for i, j in zip(shapes, true_shapes):
self.assertListEqual(list(i), list(j))

def test_g2inputs(self):

graphs = [structure2graph(i) for i in self.structures]
cg = CrystalGraph()
graphs = [cg.convert(i) for i in self.structures]
targets = [0.1, 0.2]
inp = graphs2inputs(graphs, targets)
self.assertListEqual([len(i) for i in inp], [2] * 6)
Expand Down

0 comments on commit 0f3acb4

Please sign in to comment.