This repository has been archived by the owner on Apr 27, 2023. It is now read-only.

Commit

apply black
Chi Chen committed Jan 23, 2022
1 parent 22968cf commit 666114c
Showing 4 changed files with 58 additions and 81 deletions.
126 changes: 54 additions & 72 deletions megnet/data/graph.py
@@ -49,11 +49,11 @@ class StructureGraph(MSONable):
     """

     def __init__(
-            self,
-            nn_strategy: Union[str, NearNeighbors] = None,
-            atom_converter: Converter = None,
-            bond_converter: Converter = None,
-            **kwargs,
+        self,
+        nn_strategy: Union[str, NearNeighbors] = None,
+        atom_converter: Converter = None,
+        bond_converter: Converter = None,
+        **kwargs,
     ):
         """
@@ -83,8 +83,7 @@ def __init__(
         self.atom_converter = atom_converter or self._get_dummy_converter()
         self.bond_converter = bond_converter or self._get_dummy_converter()

-    def convert(self, structure: Structure,
-                state_attributes: List = None) -> Dict:
+    def convert(self, structure: Structure, state_attributes: List = None) -> Dict:
         """
         Take a pymatgen structure and convert it to a index-type graph representation
         The graph will have node, distance, index1, index2, where node is a vector of Z number
@@ -98,17 +97,14 @@ def convert(self, structure: Structure,
             (dictionary)
         """
         state_attributes = (
-            state_attributes or getattr(structure, "state",
-                                        None) or np.array([[0.0, 0.0]],
-                                                          dtype="float32")
+            state_attributes or getattr(structure, "state", None) or np.array([[0.0, 0.0]], dtype="float32")
         )
         index1 = []
         index2 = []
         bonds = []
         if self.nn_strategy is None:
             raise RuntimeError("NearNeighbor strategy is not provided!")
-        for n, neighbors in enumerate(
-                self.nn_strategy.get_all_nn_info(structure)):
+        for n, neighbors in enumerate(self.nn_strategy.get_all_nn_info(structure)):
             index1.extend([n] * len(neighbors))
             for neighbor in neighbors:
                 index2.append(neighbor["site_index"])
@@ -117,8 +113,7 @@ def convert(self, structure: Structure,
         if np.size(np.unique(index1)) < len(atoms):
             logger.warning("Isolated atoms found in the structure")

-        return {"atom": atoms, "bond": bonds, "state": state_attributes,
-                "index1": index1, "index2": index2}
+        return {"atom": atoms, "bond": bonds, "state": state_attributes, "index1": index1, "index2": index2}

     @staticmethod
     def get_atom_features(structure) -> List[Any]:
@@ -129,8 +124,7 @@ def get_atom_features(structure) -> List[Any]:
         Returns:
             List of atomic numbers
         """
-        return np.array([i.specie.Z for i in structure],
-                        dtype="int32").tolist()
+        return np.array([i.specie.Z for i in structure], dtype="int32").tolist()

     def __call__(self, structure: Structure) -> Dict:
         """
@@ -235,8 +229,7 @@ class StructureGraphFixedRadius(StructureGraph):
     pymatgen. It is orders of magnitude faster than previous implementations
     """

-    def convert(self, structure: Structure,
-                state_attributes: List = None) -> Dict:
+    def convert(self, structure: Structure, state_attributes: List = None) -> Dict:
         """
         Take a pymatgen structure and convert it to a index-type graph representation
         The graph will have node, distance, index1, index2, where node is a vector of Z number
@@ -250,28 +243,21 @@ def convert(self, structure: Structure,
             (dictionary)
         """
         state_attributes = (
-            state_attributes or getattr(structure, "state",
-                                        None) or np.array([[0.0, 0.0]],
-                                                          dtype="float32")
+            state_attributes or getattr(structure, "state", None) or np.array([[0.0, 0.0]], dtype="float32")
         )
         atoms = self.get_atom_features(structure)
-        index1, index2, _, bonds = get_graphs_within_cutoff(structure,
-                                                            self.nn_strategy.cutoff)
+        index1, index2, _, bonds = get_graphs_within_cutoff(structure, self.nn_strategy.cutoff)

         if len(index1) == 0:
-            raise RuntimeError("The cutoff is too small, resulting in "
-                               "material graph with no bonds")
+            raise RuntimeError("The cutoff is too small, resulting in " "material graph with no bonds")

         if np.size(np.unique(index1)) < len(atoms):
-            logger.warning("Isolated atoms found in the structure. The "
-                           "cutoff radius might be small")
+            logger.warning("Isolated atoms found in the structure. The " "cutoff radius might be small")

-        return {"atom": atoms, "bond": bonds, "state": state_attributes,
-                "index1": index1, "index2": index2}
+        return {"atom": atoms, "bond": bonds, "state": state_attributes, "index1": index1, "index2": index2}

     @classmethod
-    def from_structure_graph(cls,
-                             structure_graph: StructureGraph) -> "StructureGraphFixedRadius":
+    def from_structure_graph(cls, structure_graph: StructureGraph) -> "StructureGraphFixedRadius":
         """
         Initialize from pymatgen StructureGraph
         Args:
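The fixed-radius variant above builds bonds from a hard distance cutoff and, as the diff shows, raises when the cutoff produces no bonds at all. A rough sketch of both paths, assuming megnet's CrystalGraph (a StructureGraphFixedRadius subclass in megnet.data.crystal) accepts a cutoff argument; the structure and cutoff values are arbitrary.

    # Rough sketch; CrystalGraph usage and cutoff values are assumptions.
    from pymatgen.core import Lattice, Structure

    from megnet.data.crystal import CrystalGraph

    si = Structure(Lattice.cubic(5.43), ["Si"] * 2, [[0, 0, 0], [0.25, 0.25, 0.25]])

    graph = CrystalGraph(cutoff=4.0).convert(si)
    print(len(graph["index1"]), "bonds found within 4 A")

    try:
        CrystalGraph(cutoff=0.1).convert(si)   # no neighbor lies within 0.1 A
    except RuntimeError as exc:
        print(exc)   # "The cutoff is too small, resulting in material graph with no bonds"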
@@ -332,8 +318,7 @@ class GaussianDistance(Converter):
     Expand distance with Gaussian basis sit at centers and with width 0.5.
     """

-    def __init__(self, centers: np.ndarray = np.linspace(0, 5, 100),
-                 width=0.5):
+    def __init__(self, centers: np.ndarray = np.linspace(0, 5, 100), width=0.5):
         """
         Args:
@@ -352,8 +337,7 @@ def convert(self, d: np.ndarray) -> np.ndarray:
             (matrix) N*M matrix with N the length of d and M the length of centers
         """
         d = np.array(d)
-        return np.exp(
-            -((d[:, None] - self.centers[None, :]) ** 2) / self.width ** 2)
+        return np.exp(-((d[:, None] - self.centers[None, :]) ** 2) / self.width ** 2)


 class BaseGraphBatchGenerator(Sequence):
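GaussianDistance, touched above, expands each scalar bond distance d into a smooth basis vector exp(-(d - c_k)^2 / w^2) over the centers c_k. A minimal numpy check of the shape and formula (distance values chosen arbitrarily):

    # Minimal sketch of the Gaussian basis expansion.
    import numpy as np

    from megnet.data.graph import GaussianDistance

    gd = GaussianDistance(centers=np.linspace(0, 5, 100), width=0.5)
    d = np.array([1.0, 2.5, 3.7])      # three bond lengths in angstrom
    expanded = gd.convert(d)
    print(expanded.shape)              # (3, 100): one 100-dim basis row per distance

    # The same formula written out directly, as in the return above.
    manual = np.exp(-((d[:, None] - gd.centers[None, :]) ** 2) / gd.width ** 2)
    assert np.allclose(expanded, manual)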
@@ -366,12 +350,12 @@ class BaseGraphBatchGenerator(Sequence):
     """

     def __init__(
-            self,
-            dataset_size: int,
-            targets: np.ndarray,
-            sample_weights: np.ndarray = None,
-            batch_size: int = 128,
-            is_shuffle: bool = True,
+        self,
+        dataset_size: int,
+        targets: np.ndarray,
+        sample_weights: np.ndarray = None,
+        batch_size: int = 128,
+        is_shuffle: bool = True,
     ):
         """
         Args:
@@ -403,12 +387,12 @@ def __len__(self) -> int:
         return self.max_step

     def _combine_graph_data(
-            self,
-            feature_list_temp: List[np.ndarray],
-            connection_list_temp: List[np.ndarray],
-            global_list_temp: List[np.ndarray],
-            index1_temp: List[np.ndarray],
-            index2_temp: List[np.ndarray],
+        self,
+        feature_list_temp: List[np.ndarray],
+        connection_list_temp: List[np.ndarray],
+        global_list_temp: List[np.ndarray],
+        index1_temp: List[np.ndarray],
+        index2_temp: List[np.ndarray],
     ) -> tuple:
         """Compile the matrices describing each graph into single matrices for the entire graph
         Beyond concatenating the graph descriptions, this operation updates the indices of each
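The docstring above notes that _combine_graph_data does more than concatenate: each graph's bond index lists are shifted by the number of atoms that precede that graph in the batch, so they still point at the right rows of the stacked atom matrix. A toy numpy illustration of that offsetting idea (not the library's code):

    # Toy illustration of the index-offset idea, not megnet's implementation.
    import numpy as np

    atoms_per_graph = [3, 2]                                     # two graphs in a batch
    index1_per_graph = [np.array([0, 1, 2]), np.array([0, 1])]   # per-graph bond indices

    offsets = np.cumsum([0] + atoms_per_graph[:-1])              # [0, 3]
    index1_batched = np.concatenate(
        [idx + off for idx, off in zip(index1_per_graph, offsets)]
    )
    print(index1_batched)   # [0 1 2 3 4] -- now valid rows of the stacked atom matrix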
@@ -516,8 +500,7 @@ def process_state_feature(self, x: np.ndarray) -> np.ndarray:

     def __getitem__(self, index: int) -> tuple:
         # Get the indices for this batch
-        batch_index = self.mol_index[
-            index * self.batch_size: (index + 1) * self.batch_size]
+        batch_index = self.mol_index[index * self.batch_size : (index + 1) * self.batch_size]

         # Get the inputs for each batch
         inputs = self._generate_inputs(batch_index)
@@ -560,16 +543,16 @@ class GraphBatchGenerator(BaseGraphBatchGenerator):
     """

     def __init__(
-            self,
-            atom_features: List[np.ndarray],
-            bond_features: List[np.ndarray],
-            state_features: List[np.ndarray],
-            index1_list: List[int],
-            index2_list: List[int],
-            targets: np.ndarray = None,
-            sample_weights: np.ndarray = None,
-            batch_size: int = 128,
-            is_shuffle: bool = True,
+        self,
+        atom_features: List[np.ndarray],
+        bond_features: List[np.ndarray],
+        state_features: List[np.ndarray],
+        index1_list: List[int],
+        index2_list: List[int],
+        targets: np.ndarray = None,
+        sample_weights: np.ndarray = None,
+        batch_size: int = 128,
+        is_shuffle: bool = True,
     ):
         """
         Args:
@@ -587,8 +570,7 @@ def __init__(
             batch_size: (int) number of samples in a batch
         """
         super().__init__(
-            len(atom_features), targets, sample_weights=sample_weights,
-            batch_size=batch_size, is_shuffle=is_shuffle
+            len(atom_features), targets, sample_weights=sample_weights, batch_size=batch_size, is_shuffle=is_shuffle
         )
         self.atom_features = atom_features
         self.bond_features = bond_features
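GraphBatchGenerator, whose signature is reformatted above, is a Keras Sequence that batches pre-converted graphs. A rough, self-contained sketch of feeding convert() output into it; the structures, targets and batch size are placeholders chosen for illustration.

    # Rough sketch; the dataset and hyperparameters are illustrative assumptions.
    import numpy as np
    from pymatgen.core import Lattice, Structure

    from megnet.data.crystal import CrystalGraph
    from megnet.data.graph import GraphBatchGenerator

    structures = [
        Structure(Lattice.cubic(5.43), ["Si"] * 2, [[0, 0, 0], [0.25, 0.25, 0.25]]),
        Structure(Lattice.cubic(3.6), ["Cu"], [[0, 0, 0]]),
    ]
    targets = np.array([1.1, 2.2])    # made-up property values

    converter = CrystalGraph(cutoff=4.0)
    graphs = [converter.convert(s) for s in structures]

    gen = GraphBatchGenerator(
        atom_features=[g["atom"] for g in graphs],
        bond_features=[g["bond"] for g in graphs],
        state_features=[g["state"] for g in graphs],
        index1_list=[g["index1"] for g in graphs],
        index2_list=[g["index2"] for g in graphs],
        targets=targets,
        batch_size=2,
    )
    print(len(gen))   # number of batches per epoch (__len__ returns max_step)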
@@ -625,17 +607,17 @@ class GraphBatchDistanceConvert(GraphBatchGenerator):
     """

     def __init__(
-            self,
-            atom_features: List[np.ndarray],
-            bond_features: List[np.ndarray],
-            state_features: List[np.ndarray],
-            index1_list: List[int],
-            index2_list: List[int],
-            targets: np.ndarray = None,
-            sample_weights: np.ndarray = None,
-            batch_size: int = 128,
-            is_shuffle: bool = True,
-            distance_converter: Converter = None,
+        self,
+        atom_features: List[np.ndarray],
+        bond_features: List[np.ndarray],
+        state_features: List[np.ndarray],
+        index1_list: List[int],
+        index2_list: List[int],
+        targets: np.ndarray = None,
+        sample_weights: np.ndarray = None,
+        batch_size: int = 128,
+        is_shuffle: bool = True,
+        distance_converter: Converter = None,
     ):
         """
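GraphBatchDistanceConvert, at the end of this file, differs from GraphBatchGenerator only in that it expands the raw bond distances on the fly with the supplied distance_converter. A rough sketch under the same assumptions as the previous example:

    # Rough sketch; structure, cutoff and basis parameters are assumptions.
    import numpy as np
    from pymatgen.core import Lattice, Structure

    from megnet.data.crystal import CrystalGraph
    from megnet.data.graph import GaussianDistance, GraphBatchDistanceConvert

    si = Structure(Lattice.cubic(5.43), ["Si"] * 2, [[0, 0, 0], [0.25, 0.25, 0.25]])
    g = CrystalGraph(cutoff=4.0).convert(si)

    gen = GraphBatchDistanceConvert(
        atom_features=[g["atom"]],
        bond_features=[g["bond"]],     # raw distances; expanded batch by batch
        state_features=[g["state"]],
        index1_list=[g["index1"]],
        index2_list=[g["index2"]],
        targets=np.array([0.0]),
        batch_size=1,
        distance_converter=GaussianDistance(np.linspace(0, 5, 100), 0.5),
    )
    inputs = gen[0][0]
    print(inputs[1].shape)   # expanded bond features, expected (1, n_bonds, 100)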
7 changes: 2 additions & 5 deletions megnet/layers/graph/megnet.py
@@ -237,9 +237,7 @@ def rho_e_v(self, e_p, inputs):
         """
         node, edges, u, index1, index2, gnode, gbond = inputs
         index1 = tf.reshape(index1, (-1,))
-        return tf.expand_dims(self.unsorted_seg_method(
-            tf.squeeze(e_p), index1, num_segments=tf.shape(node)[1]),
-            axis=0)
+        return tf.expand_dims(self.unsorted_seg_method(tf.squeeze(e_p), index1, num_segments=tf.shape(node)[1]), axis=0)

     def phi_v(self, b_ei_p, inputs):
         """
@@ -268,8 +266,7 @@ def rho_e_u(self, e_p, inputs):
         """
         nodes, edges, u, index1, index2, gnode, gbond = inputs
         gbond = tf.reshape(gbond, (-1,))
-        return tf.expand_dims(self.seg_method(tf.squeeze(e_p), gbond),
-                              axis=0)
+        return tf.expand_dims(self.seg_method(tf.squeeze(e_p), gbond), axis=0)

     def rho_v_u(self, v_p, inputs):
         """
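The two returns cleaned up above are edge aggregations: rho_e_v pools bond features onto the atoms indexed by index1 (with a configurable unsorted segment op), and rho_e_u pools bond features onto their parent graph. A standalone TensorFlow sketch of the same pattern with made-up shapes:

    # Standalone illustration of the segment-aggregation pattern; shapes are made up.
    import tensorflow as tf

    edge_feat = tf.constant([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])  # 3 bonds, 2 features
    index1 = tf.constant([0, 0, 1])   # atom index of each bond
    gbond = tf.constant([0, 0, 0])    # graph id of each bond (a single graph here)

    # rho_e_v-style: pool bond features onto atoms (2 atoms in total).
    per_atom = tf.math.unsorted_segment_sum(edge_feat, index1, num_segments=2)
    print(per_atom.numpy())           # [[3. 3.] [3. 3.]]

    # rho_e_u-style: pool bond features onto their graph.
    per_graph = tf.math.segment_mean(edge_feat, gbond)
    print(per_graph.numpy())          # [[2. 2.]]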
3 changes: 1 addition & 2 deletions megnet/layers/readout/set2set.py
@@ -171,8 +171,7 @@ def call(self, inputs, mask=None):
         for i in range(self.T):
             self.h, c = self._lstm(q_star, self.c)
             e_i_t = tf.reduce_sum(input_tensor=m * tf.repeat(self.h, repeats=counts, axis=1), axis=-1)
-            maxes = tf.math.segment_max(e_i_t[0],
-                                        feature_graph_index)
+            maxes = tf.math.segment_max(e_i_t[0], feature_graph_index)
             maxes = tf.repeat(maxes, repeats=counts)
             e_i_t -= tf.expand_dims(maxes, axis=0)
             # e_i_t -= tf.expand_dims(tf.gather(maxes, feature_graph_index,
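The set2set change above sits inside the attention step: each graph's maximum score is subtracted before exponentiation so the segment-wise softmax stays numerically stable. A standalone sketch of that trick on made-up values:

    # Standalone sketch of segment-wise softmax with max subtraction; values are made up.
    import tensorflow as tf

    logits = tf.constant([10.0, 11.0, 50.0])      # attention scores for 3 nodes
    feature_graph_index = tf.constant([0, 0, 1])  # nodes 0-1 in graph 0, node 2 in graph 1
    counts = tf.constant([2, 1])                  # nodes per graph

    maxes = tf.math.segment_max(logits, feature_graph_index)   # [11., 50.]
    stable = logits - tf.repeat(maxes, repeats=counts)         # subtract each graph's max

    exp = tf.exp(stable)
    denom = tf.repeat(tf.math.segment_sum(exp, feature_graph_index), repeats=counts)
    print((exp / denom).numpy())   # graph 0: softmax over two nodes; graph 1: exactly 1.0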
3 changes: 1 addition & 2 deletions megnet/utils/preprocessing.py
@@ -126,8 +126,7 @@ def from_training_data(
         return cls(mean, std, is_intensive)

     def __str__(self):
-        return f"StandardScaler(mean={self.mean:.3f}, std={self.std:.3f}, " \
-               f"is_intensive={self.is_intensive})"
+        return f"StandardScaler(mean={self.mean:.3f}, std={self.std:.3f}, " f"is_intensive={self.is_intensive})"

     def __repr__(self):
         return str(self)
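The only change here joins the __str__ f-string onto one line. For reference, a tiny sketch of the resulting representation; the constructor call mirrors the cls(mean, std, is_intensive) visible above and the numbers are arbitrary.

    # Tiny sketch; the numeric values are arbitrary.
    from megnet.utils.preprocessing import StandardScaler

    scaler = StandardScaler(1.234, 0.567, True)
    print(scaler)   # StandardScaler(mean=1.234, std=0.567, is_intensive=True)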
