
Commit

updated docstrings and cleaned unused keys
BowenD-UCB committed Jun 27, 2023
1 parent 73d3219 commit 31969c5
Showing 5 changed files with 131 additions and 95 deletions.
11 changes: 5 additions & 6 deletions chgnet/graph/converter.py
@@ -16,9 +16,9 @@


class CrystalGraphConverter(nn.Module):
"""Convert a pymatgen.core.Structure to a CrystalGraph.
Only the minimal essential information is kept.
"""Convert a pymatgen.core.Structure to a CrystalGraph
The CrystalGraph dataclass stores essential field to make sure that
gradients like force and stress can be calculated through back-propagation later.
"""

def __init__(
@@ -31,8 +31,6 @@ def __init__(
atom_graph. Default = 5
bond_graph_cutoff (float): bond length threshold to include bond in bond_graph
Default = 3
verbose (bool): whether to print initialization message
Default = True
"""
super().__init__()
self.atom_graph_cutoff = atom_graph_cutoff
@@ -57,7 +55,8 @@ def forward(
mp_id (str): Materials Project id of this structure
Default = None
on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures
with isolated atoms. Default = 'error'
with isolated atoms.
Default = 'error'
Return:
Crystal_Graph that is ready to use by CHGNet
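For reference, a minimal usage sketch of the converter documented above; the import path `chgnet.graph.CrystalGraphConverter` and the CIF file name are illustrative assumptions, not part of this commit:

    from pymatgen.core import Structure

    from chgnet.graph import CrystalGraphConverter

    # Placeholder structure file; any pymatgen Structure works.
    structure = Structure.from_file("LiMnO2.cif")

    # Cutoffs mirror the documented defaults: 5 A for atom_graph, 3 A for bond_graph.
    converter = CrystalGraphConverter(atom_graph_cutoff=5, bond_graph_cutoff=3)

    # on_isolated_atoms defaults to "error"; "warn" only emits a warning instead.
    graph = converter(structure, graph_id="example", mp_id=None, on_isolated_atoms="warn")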
21 changes: 15 additions & 6 deletions chgnet/graph/crystalgraph.py
@@ -38,23 +38,32 @@ def __init__(
[n_atom, 3]
atom_graph (Tensor): a directed graph adjacency list,
(center atom indices, neighbor atom indices, undirected bond index)
for bonds in bond_fea [2*n_bond, 3]
for bonds in bond_fea
[num_directed_bonds, 2]
atom_graph_cutoff (float): the cutoff radius to draw edges in atom_graph
neighbor_image (Tensor): the periodic image specifying the location of
neighboring atom [2*n_bond, 2]
neighboring atom
see: https://github.com/materialsproject/pymatgen/blob/ca2175c762e37ea7
c9f3950ef249bc540e683da1/pymatgen/core/structure.py#L1485-L1541
[num_directed_bonds, 3]
directed2undirected (Tensor): the mapping from directed edge index to
undirected edge index for the atom graph [2*n_bond]
undirected edge index for the atom graph
[num_directed_bonds]
undirected2directed (Tensor): the mapping from undirected edge index to
one of its directed edge index, this is essentially the inverse
mapping of the directed2undirected. This tensor is needed for
computation efficiency. [n_bond]
computation efficiency.
Note that num_directed_bonds = 2 * num_undirected_bonds
[num_undirected_bonds]
bond_graph (Tensor): a directed graph adjacency list,
(atom indices, 1st undirected bond idx, 1st directed bond idx,
2nd undirected bond idx, 2nd directed bond idx)
for angles in angle_fea [n_angle, 5]
for angles in angle_fea
[n_angle, 5]
bond_graph_cutoff (float): the cutoff bond length to include bond
as nodes in bond_graph
lattice (Tensor): lattices of the input structure [3, 3]
lattice (Tensor): lattices of the input structure
[3, 3]
graph_id (str or None): an id to keep track of this crystal graph
Default = None
mp_id (str or None): Materials Project id of this structure
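A short sketch of the shape bookkeeping described in this docstring, assuming `graph` is a CrystalGraph returned by the converter and that attribute names mirror the constructor arguments above:

    # Every undirected bond appears twice in the directed adjacency list,
    # so num_directed_bonds = 2 * num_undirected_bonds, as noted in the docstring.
    num_directed_bonds = graph.atom_graph.shape[0]
    num_undirected_bonds = graph.undirected2directed.shape[0]
    assert num_directed_bonds == 2 * num_undirected_bonds

    # directed2undirected maps each directed edge back to its undirected bond index.
    assert graph.directed2undirected.shape[0] == num_directed_bonds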
29 changes: 19 additions & 10 deletions chgnet/model/dynamics.py
@@ -59,12 +59,15 @@ def __init__(
"""Provide a CHGNet instance to calculate various atomic properties using ASE.
Args:
model (CHGNet): instance of a chgnet model
model (CHGNet): instance of a chgnet model. If set to None,
the pretrained CHGNet is loaded.
Default = None
use_device (str, optional): The device to be used for predictions,
either "cpu", "cuda", or "mps". If not specified, the default device is
automatically selected based on the available options.
Default = None
stress_weight (float): the conversion factor to convert GPa to eV/A^3.
Default = 1/160.21.
Default = 1/160.21
**kwargs: Passed to the Calculator parent class.
"""
super().__init__(**kwargs)
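A minimal ASE usage sketch for the calculator described above; the `ase.build.bulk` structure is an arbitrary example and assumes the pretrained weights can be loaded:

    from ase.build import bulk

    from chgnet.model.dynamics import CHGNetCalculator

    atoms = bulk("Cu", "fcc", a=3.6)

    # model=None loads the pretrained CHGNet; use_device=None picks a device automatically.
    atoms.calc = CHGNetCalculator(model=None, use_device=None)

    energy = atoms.get_potential_energy()  # eV
    forces = atoms.get_forces()            # eV/A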
@@ -136,14 +139,17 @@ def __init__(
"""Provide a trained CHGNet model and an optimizer to relax crystal structures.
Args:
model (CHGNet): instance of a chgnet model
model (CHGNet): instance of a chgnet model. If set to None,
the pretrained CHGNet is loaded.
Default = None
optimizer_class (Optimizer,str): choose optimizer from ASE.
Default = FIRE
Default = "FIRE"
use_device (str, optional): The device to be used for predictions,
either "cpu", "cuda", or "mps". If not specified, the default device is
automatically selected based on the available options.
Default = None
stress_weight (float): the conversion factor to convert GPa to eV/A^3.
Default = 1/160.21.
Default = 1/160.21
"""
if isinstance(optimizer_class, str):
if optimizer_class in OPTIMIZERS:
@@ -188,8 +194,8 @@ def relax(
**kwargs: Additional parameters for the optimizer.
Returns:
dict[str, Structure | TrajectoryObserver]: A dictionary with keys 'final_structure'
and 'trajectory'.
dict[str, Structure | TrajectoryObserver]:
A dictionary with 'final_structure' and 'trajectory'.
"""
if isinstance(atoms, Structure):
atoms = AseAtomsAdaptor.get_atoms(atoms)
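A relaxation sketch matching the return value documented above (a dict with 'final_structure' and 'trajectory'); the `fmax`/`steps` values and the CIF file name are illustrative assumptions:

    from pymatgen.core import Structure

    from chgnet.model.dynamics import StructOptimizer

    structure = Structure.from_file("LiMnO2.cif")  # placeholder file name

    relaxer = StructOptimizer()  # pretrained CHGNet with the default FIRE optimizer
    result = relaxer.relax(structure, fmax=0.1, steps=500)

    final_structure = result["final_structure"]  # pymatgen Structure
    trajectory = result["trajectory"]            # TrajectoryObserver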
@@ -453,14 +459,17 @@ def __init__(
"""Initialize a structure optimizer object for calculation of bulk modulus.
Args:
model (CHGNet): instance of a chgnet model
model (CHGNet): instance of a chgnet model. If set to None,
the pretrained CHGNet is loaded.
Default = None
optimizer_class (Optimizer,str): choose optimizer from ASE.
Default = FIRE
Default = "FIRE"
use_device (str, optional): The device to be used for predictions,
either "cpu", "cuda", or "mps". If not specified, the default device is
automatically selected based on the available options.
Default = None
stress_weight (float): the conversion factor to convert GPa to eV/A^3.
Default = 1/160.21.
Default = 1/160.21
"""
self.relaxer = StructOptimizer(
model=model,
150 changes: 80 additions & 70 deletions chgnet/model/layers.py
@@ -25,31 +25,33 @@ def __init__(
norm: str | None = None,
use_mlp_out: bool = True,
resnet: bool = True,
gMLP_norm: str = "batch",
gMLP_norm: str | None = None,
) -> None:
"""Args:
atom_fea_dim (int): The dimensionality of the input atom features.
bond_fea_dim (int): The dimensionality of the input bond features.
hidden_dim (int, optional): The dimensionality of the hidden layers in the
gated MLP.
Default = 64.
dropout (float, optional): The dropout probability to apply to the gated MLP.
Default = 0.
activation (str, optional): The name of the activation function to use in the
gated MLP.
Must be one of "relu", "silu", "tanh", or "gelu". Default = "silu".
norm (str, optional): The name of the normalization layer to use on the updated
atom features. Must be one of "batch", "layer", or None.
Default = None.
use_mlp_out (bool, optional): Whether to apply an MLP output layer to the
updated atom features.
Default = True.
resnet (bool, optional): Whether to apply a residual connection to the
updated atom features.
Default = True.
gMLP_norm (str, optional): The name of the normalization layer to use on the
gated MLP. Must be one of "batch", "layer", or None. Default = "batch".
**kwargs: Additional keyword arguments to pass to the normalization layer.
"""Initialize the AtomConv layer.
Args:
atom_fea_dim (int): The dimensionality of the input atom features.
bond_fea_dim (int): The dimensionality of the input bond features.
hidden_dim (int, optional): The dimensionality of the hidden layers in the
gated MLP.
Default = 64
dropout (float, optional): The dropout rate to apply to the gated MLP.
Default = 0.
activation (str, optional): The name of the activation function to use in
the gated MLP. Must be one of "relu", "silu", "tanh", or "gelu".
Default = "silu"
norm (str, optional): The name of the normalization layer to use on the
updated atom features. Must be one of "batch", "layer", or None.
Default = None
use_mlp_out (bool, optional): Whether to apply an MLP output layer to the
updated atom features.
Default = True
resnet (bool, optional): Whether to apply a residual connection to the
updated atom features.
Default = True
gMLP_norm (str, optional): The name of the normalization layer to use on the
gated MLP. Must be one of "batch", "layer", or None.
Default = None
"""
super().__init__()
self.use_mlp_out = use_mlp_out
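A construction sketch reflecting the new explicit gMLP_norm keyword; the feature dimensions are arbitrary example values:

    from chgnet.model.layers import AtomConv

    # gMLP_norm now defaults to None instead of "batch"; pass it explicitly to
    # normalize the gated MLP independently of the `norm` used on atom features.
    atom_conv = AtomConv(
        atom_fea_dim=64,
        bond_fea_dim=64,
        hidden_dim=64,
        dropout=0,
        activation="silu",
        norm=None,
        use_mlp_out=True,
        resnet=True,
        gMLP_norm="layer",
    )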
@@ -140,31 +142,34 @@ def __init__(
norm: str | None = None,
use_mlp_out: bool = True,
resnet=True,
**kwargs,
gMLP_norm: str | None = None,
) -> None:
"""Args:
atom_fea_dim (int): The dimensionality of the input atom features.
bond_fea_dim (int): The dimensionality of the input bond features.
angle_fea_dim (int): The dimensionality of the input angle features.
hidden_dim (int, optional): The dimensionality of the hidden layers
in the gated MLP.
Default = 64.
dropout (float, optional): The dropout probability to apply to the gated MLP.
Default = 0.
activation (str, optional): The name of the activation function to use
in the gated MLP.
Must be one of "relu", "silu", "tanh", or "gelu". Default = "silu".
norm (str, optional): The name of the normalization layer to use on the
updated atom features.
Must be one of "batch", "layer", or None.
Default = None.
use_mlp_out (bool, optional): Whether to apply an MLP output layer to the
updated atom features.
Default = True.
resnet (bool, optional): Whether to apply a residual connection to the
updated atom features.
Default = True.
**kwargs: Additional keyword arguments to pass to the normalization layer.
"""Initialize the BondConv layer.
Args:
atom_fea_dim (int): The dimensionality of the input atom features.
bond_fea_dim (int): The dimensionality of the input bond features.
angle_fea_dim (int): The dimensionality of the input angle features.
hidden_dim (int, optional): The dimensionality of the hidden layers
in the gated MLP.
Default = 64
dropout (float, optional): The dropout rate to apply to the gated MLP.
Default = 0.
activation (str, optional): The name of the activation function to use
in the gated MLP. Must be one of "relu", "silu", "tanh", or "gelu".
Default = "silu"
norm (str, optional): The name of the normalization layer to use on the
updated atom features. Must be one of "batch", "layer", or None.
Default = None
use_mlp_out (bool, optional): Whether to apply an MLP output layer to the
updated atom features.
Default = True
resnet (bool, optional): Whether to apply a residual connection to the
updated atom features.
Default = True
gMLP_norm (str, optional): The name of the normalization layer to use on the
gated MLP. Must be one of "batch", "layer", or None.
Default = None
"""
super().__init__()
self.use_mlp_out = use_mlp_out
Expand All @@ -175,7 +180,7 @@ def __init__(
output_dim=bond_fea_dim,
hidden_dim=hidden_dim,
dropout=dropout,
norm=kwargs.pop("gMLP_norm", "batch"),
norm=gMLP_norm,
activation=activation,
)
if self.use_mlp_out:
@@ -202,13 +207,13 @@ def forward(
bond_weights (Tensor): BondGraph bond weights with shape
[num_undirected_bonds, bond_fea_dim]
angle_feas (Tensor): angle features tensor with shape
[num_batch_angles, atom_fea_dim]
[num_batch_angles, angle_fea_dim]
bond_graph (Tensor): Directed BondGraph tensor with shape
[num_batched_angles, 3]
Returns:
new_bond_feas (Tensor): bond feature tensor with shape
[num_batch_atom, bond_fea_dim]
[num_undirected_bonds, bond_fea_dim]
Notes:
- num_batch_atoms = sum(num_atoms) in batch
@@ -257,30 +262,31 @@ def __init__(
activation: str = "silu",
norm: str | None = None,
resnet: bool = True,
**kwargs,
gMLP_norm: str | None = None,
) -> None:
"""Create a new AngleUpdate instance.
"""Initialize the AngleUpdate layer.
Args:
atom_fea_dim (int): The dimensionality of the input atom features.
bond_fea_dim (int): The dimensionality of the input bond features.
angle_fea_dim (int): The dimensionality of the input angle features.
hidden_dim (int, optional): The dimensionality of the hidden layers
in the gated MLP.
Default = 0.
dropout (float, optional): The dropout probability to apply to the gated MLP.
Default = 0
dropout (float, optional): The dropout rate to apply to the gated MLP.
Default = 0.
activation (str, optional): The name of the activation function to use
in the gated MLP. Must be one of "relu", "silu", "tanh", or "gelu".
Default = "silu".
Default = "silu"
norm (str, optional): The name of the normalization layer to use on the
updated atom features.
Must be one of "batch", "layer", or None.
Default = None.
updated atom features. Must be one of "batch", "layer", or None.
Default = None
resnet (bool, optional): Whether to apply a residual connection to the
updated atom features.
Default = True.
**kwargs: Additional keyword arguments to pass to the normalization layer.
Default = True
gMLP_norm (str, optional): The name of the normalization layer to use on the
gated MLP. Must be one of "batch", "layer", or None.
Default = None
"""
super().__init__()
self.resnet = resnet
Expand All @@ -290,7 +296,7 @@ def __init__(
output_dim=angle_fea_dim,
hidden_dim=hidden_dim,
dropout=dropout,
norm=kwargs.pop("gMLP_norm", "batch"),
norm=gMLP_norm,
activation=activation,
)
self.angle_norm = find_normalization(norm, dim=angle_fea_dim)
@@ -310,13 +316,13 @@ def forward(
bond_feas (Tensor): bond features tensor with shape
[num_undirected_bonds, bond_fea_dim]
angle_feas (Tensor): angle features tensor with shape
[num_batch_angles, atom_fea_dim]
[num_batch_angles, angle_fea_dim]
bond_graph (Tensor): Directed BondGraph tensor with shape
[num_batched_angles, 3]
Returns:
new_angle_feas (Tensor): angle features tensor with shape
[num_batch_angles, atom_fea_dim]
[num_batch_angles, angle_fea_dim]
Notes:
- num_batch_atoms = sum(num_atoms) in batch
@@ -355,11 +361,13 @@ def forward(self, atom_feas: Tensor, atom_owner: Tensor) -> Tensor:
Args:
atom_feas (Tensor): batched atom features after convolution layers.
shape = [num_batch_atoms, atom_fea_dim]
atom_owner (Tensor): graph indices for each atom. shape = [num_batch_atoms]
[num_batch_atoms, atom_fea_dim or 1]
atom_owner (Tensor): graph indices for each atom.
[num_batch_atoms]
Returns:
crystal_feas (Tensor): crystal feature matrix. shape = [n_crystals, atom_fea_dim]
crystal_feas (Tensor): crystal feature matrix.
[n_crystals, atom_fea_dim or 1]
"""
return aggregate(atom_feas, atom_owner, average=self.average)

@@ -392,11 +400,13 @@ def forward(self, atom_feas: Tensor, atom_owner: Tensor) -> Tensor:
Args:
atom_feas (Tensor): batched atom features after convolution layers.
shape = [num_batch_atoms, atom_fea_dim]
atom_owner (Tensor): graph indices for each atom. shape = [num_batch_atoms]
[num_batch_atoms, atom_fea_dim]
atom_owner (Tensor): graph indices for each atom.
[num_batch_atoms]
Returns:
crystal_feas (Tensor): crystal feature matrix. shape = [n_crystals, atom_fea_dim]
crystal_feas (Tensor): crystal feature matrix.
[n_crystals, atom_fea_dim]
"""
crystal_feas = []
weights = self.key(atom_feas) # [n_batch_atom, n_heads]
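A small self-contained sketch of the mean-pooling semantics these docstrings describe (an equivalent computation in plain torch, not CHGNet's own aggregate implementation):

    import torch

    atom_feas = torch.randn(7, 64)                    # [num_batch_atoms, atom_fea_dim]
    atom_owner = torch.tensor([0, 0, 0, 1, 1, 2, 2])  # graph index of each atom

    # Average the features of all atoms belonging to the same crystal.
    n_crystals = int(atom_owner.max()) + 1
    sums = torch.zeros(n_crystals, atom_feas.shape[1]).index_add_(0, atom_owner, atom_feas)
    counts = torch.bincount(atom_owner, minlength=n_crystals).unsqueeze(1)
    crystal_feas = sums / counts                      # [n_crystals, atom_fea_dim]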

