diff --git a/chgnet/graph/converter.py b/chgnet/graph/converter.py
index 7cb9d225..e548574c 100644
--- a/chgnet/graph/converter.py
+++ b/chgnet/graph/converter.py
@@ -16,9 +16,9 @@ class CrystalGraphConverter(nn.Module):
-    """Convert a pymatgen.core.Structure to a CrystalGraph.
-
-    Only the minimal essential information is kept.
+    """Convert a pymatgen.core.Structure to a CrystalGraph.
+    The CrystalGraph dataclass stores the essential fields needed so that
+    gradients like force and stress can later be calculated through back-propagation.
     """

     def __init__(
@@ -31,8 +31,6 @@ def __init__(
                 atom_graph. Default = 5
             bond_graph_cutoff (float): bond length threshold to include bond in bond_graph
                 Default = 3
-            verbose (bool): whether to print initialization message
-                Default = True
         """
         super().__init__()
         self.atom_graph_cutoff = atom_graph_cutoff
@@ -57,7 +55,8 @@ def forward(
             mp_id (str): Materials Project id of this structure
                 Default = None
             on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures
-                with isolated atoms. Default = 'error'
+                with isolated atoms.
+                Default = 'error'

         Return:
             Crystal_Graph that is ready to use by CHGNet
diff --git a/chgnet/graph/crystalgraph.py b/chgnet/graph/crystalgraph.py
index 2411d3ab..dc0e21c6 100644
--- a/chgnet/graph/crystalgraph.py
+++ b/chgnet/graph/crystalgraph.py
@@ -38,23 +38,32 @@ def __init__(
                 [n_atom, 3]
             atom_graph (Tensor): a directed graph adjacency list,
-                (center atom indices, neighbor atom indices, undirected bond index)
-                for bonds in bond_fea [2*n_bond, 3]
+                (center atom indices, neighbor atom indices)
+                for bonds in bond_fea
+                [num_directed_bonds, 2]
             atom_graph_cutoff (float): the cutoff radius to draw edges in atom_graph
             neighbor_image (Tensor): the periodic image specifying the location of
-                neighboring atom [2*n_bond, 2]
+                neighboring atom
+                see: https://github.com/materialsproject/pymatgen/blob/ca2175c762e37ea7
+                c9f3950ef249bc540e683da1/pymatgen/core/structure.py#L1485-L1541
+                [num_directed_bonds, 3]
             directed2undirected (Tensor): the mapping from directed edge index to
-                undirected edge index for the atom graph [2*n_bond]
+                undirected edge index for the atom graph
+                [num_directed_bonds]
             undirected2directed (Tensor): the mapping from undirected edge index to
-                one of its directed edge index, this is essentially the inverse
-                mapping of the directed2undirected this tensor is needed for
-                computation efficiency. [n_bond]
+                one of its directed edge indices. This is essentially the inverse
+                mapping of directed2undirected; this tensor is needed for
+                computational efficiency.
+                Note that num_directed_bonds = 2 * num_undirected_bonds
+                [num_undirected_bonds]
             bond_graph (Tensor): a directed graph adjacency list,
                 (atom indices, 1st undirected bond idx, 1st directed bond idx,
                 2nd undirected bond idx, 2nd directed bond idx)
-                for angles in angle_fea [n_angle, 5]
+                for angles in angle_fea
+                [n_angle, 5]
             bond_graph_cutoff (float): the cutoff bond length to include bond as nodes in
                 bond_graph
-            lattice (Tensor): lattices of the input structure [3, 3]
+            lattice (Tensor): lattices of the input structure
+                [3, 3]
             graph_id (str or None): an id to keep track of this crystal graph
                 Default = None
-            mp_id (str) or None: Materials Project id of this structure
+            mp_id (str or None): Materials Project id of this structure
diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py
index 54f5375e..8415029d 100644
--- a/chgnet/model/dynamics.py
+++ b/chgnet/model/dynamics.py
@@ -59,12 +59,15 @@ def __init__(
         """Provide a CHGNet instance to calculate various atomic properties using ASE.

         Args:
-            model (CHGNet): instance of a chgnet model
+            model (CHGNet): instance of a chgnet model. If set to None,
+                the pretrained CHGNet is loaded.
+                Default = None
             use_device (str, optional): The device to be used for predictions,
                 either "cpu", "cuda", or "mps". If not specified, the default device is
                 automatically selected based on the available options.
+                Default = None
             stress_weight (float): the conversion factor to convert GPa to eV/A^3.
-                Default = 1/160.21.
+                Default = 1/160.21
             **kwargs: Passed to the Calculator parent class.
         """
         super().__init__(**kwargs)
@@ -136,14 +139,17 @@ def __init__(
         """Provide a trained CHGNet model and an optimizer to relax crystal structures.

         Args:
-            model (CHGNet): instance of a chgnet model
+            model (CHGNet): instance of a chgnet model. If set to None,
+                the pretrained CHGNet is loaded.
+                Default = None
             optimizer_class (Optimizer,str): choose optimizer from ASE.
-                Default = FIRE
+                Default = "FIRE"
             use_device (str, optional): The device to be used for predictions,
                 either "cpu", "cuda", or "mps". If not specified, the default device is
                 automatically selected based on the available options.
+                Default = None
             stress_weight (float): the conversion factor to convert GPa to eV/A^3.
-                Default = 1/160.21.
+                Default = 1/160.21
         """
         if isinstance(optimizer_class, str):
             if optimizer_class in OPTIMIZERS:
@@ -188,8 +194,8 @@ def relax(
             **kwargs: Additional parameters for the optimizer.

         Returns:
-            dict[str, Structure | TrajectoryObserver]: A dictionary with keys 'final_structure'
-            and 'trajectory'.
+            dict[str, Structure | TrajectoryObserver]:
+                A dictionary with 'final_structure' and 'trajectory'.
         """
         if isinstance(atoms, Structure):
             atoms = AseAtomsAdaptor.get_atoms(atoms)
@@ -453,14 +459,17 @@ def __init__(
         """Initialize a structure optimizer object for calculation of bulk modulus.

         Args:
-            model (CHGNet): instance of a chgnet model
+            model (CHGNet): instance of a chgnet model. If set to None,
+                the pretrained CHGNet is loaded.
+                Default = None
             optimizer_class (Optimizer,str): choose optimizer from ASE.
-                Default = FIRE
+                Default = "FIRE"
             use_device (str, optional): The device to be used for predictions,
                 either "cpu", "cuda", or "mps". If not specified, the default device is
                 automatically selected based on the available options.
+                Default = None
             stress_weight (float): the conversion factor to convert GPa to eV/A^3.
-                Default = 1/160.21.
+                Default = 1/160.21
         """
         self.relaxer = StructOptimizer(
             model=model,
diff --git a/chgnet/model/layers.py b/chgnet/model/layers.py
index f116d91d..f1c40ab1 100644
--- a/chgnet/model/layers.py
+++ b/chgnet/model/layers.py
@@ -25,31 +25,33 @@ def __init__(
         norm: str | None = None,
         use_mlp_out: bool = True,
         resnet: bool = True,
-        gMLP_norm: str = "batch",
+        gMLP_norm: str | None = None,
     ) -> None:
-        """Args:
-            atom_fea_dim (int): The dimensionality of the input atom features.
-            bond_fea_dim (int): The dimensionality of the input bond features.
-            hidden_dim (int, optional): The dimensionality of the hidden layers in the
-                gated MLP.
-                Default = 64.
-            dropout (float, optional): The dropout probability to apply to the gated MLP.
-                Default = 0.
-            activation (str, optional): The name of the activation function to use in the
-                gated MLP.
-                Must be one of "relu", "silu", "tanh", or "gelu". Default = "silu".
-            norm (str, optional): The name of the normalization layer to use on the updated
-                atom features. Must be one of "batch", "layer", or None.
-                Default = None.
-            use_mlp_out (bool, optional): Whether to apply an MLP output layer to the
-                updated atom features.
-                Default = True.
-            resnet (bool, optional): Whether to apply a residual connection to the
-                updated atom features.
-                Default = True.
-            gMLP_norm (str, optional): The name of the normalization layer to use on the
-                gated MLP. Must be one of "batch", "layer", or None. Default = "batch".
-            **kwargs: Additional keyword arguments to pass to the normalization layer.
+        """Initialize the AtomConv layer.
+
+        Args:
+            atom_fea_dim (int): The dimensionality of the input atom features.
+            bond_fea_dim (int): The dimensionality of the input bond features.
+            hidden_dim (int, optional): The dimensionality of the hidden layers in the
+                gated MLP.
+                Default = 64
+            dropout (float, optional): The dropout rate to apply to the gated MLP.
+                Default = 0.
+            activation (str, optional): The name of the activation function to use in
+                the gated MLP. Must be one of "relu", "silu", "tanh", or "gelu".
+                Default = "silu"
+            norm (str, optional): The name of the normalization layer to use on the
+                updated atom features. Must be one of "batch", "layer", or None.
+                Default = None
+            use_mlp_out (bool, optional): Whether to apply an MLP output layer to the
+                updated atom features.
+                Default = True
+            resnet (bool, optional): Whether to apply a residual connection to the
+                updated atom features.
+                Default = True
+            gMLP_norm (str, optional): The name of the normalization layer to use on the
+                gated MLP. Must be one of "batch", "layer", or None.
+                Default = None
         """
         super().__init__()
         self.use_mlp_out = use_mlp_out
@@ -140,31 +142,34 @@ def __init__(
         norm: str | None = None,
         use_mlp_out: bool = True,
         resnet=True,
-        **kwargs,
+        gMLP_norm: str | None = None,
     ) -> None:
-        """Args:
-            atom_fea_dim (int): The dimensionality of the input atom features.
-            bond_fea_dim (int): The dimensionality of the input bond features.
-            angle_fea_dim (int): The dimensionality of the input angle features.
-            hidden_dim (int, optional): The dimensionality of the hidden layers
-                in the gated MLP.
-                Default = 64.
-            dropout (float, optional): The dropout probability to apply to the gated MLP.
-                Default = 0.
-            activation (str, optional): The name of the activation function to use
-                in the gated MLP.
-                Must be one of "relu", "silu", "tanh", or "gelu". Default = "silu".
-            norm (str, optional): The name of the normalization layer to use on the
-                updated atom features.
-                Must be one of "batch", "layer", or None.
-                Default = None.
-            use_mlp_out (bool, optional): Whether to apply an MLP output layer to the
-                updated atom features.
-                Default = True.
-            resnet (bool, optional): Whether to apply a residual connection to the
-                updated atom features.
-                Default = True.
-            **kwargs: Additional keyword arguments to pass to the normalization layer.
+        """Initialize the BondConv layer.
+
+        Args:
+            atom_fea_dim (int): The dimensionality of the input atom features.
+            bond_fea_dim (int): The dimensionality of the input bond features.
+            angle_fea_dim (int): The dimensionality of the input angle features.
+            hidden_dim (int, optional): The dimensionality of the hidden layers
+                in the gated MLP.
+                Default = 64
+            dropout (float, optional): The dropout rate to apply to the gated MLP.
+                Default = 0.
+            activation (str, optional): The name of the activation function to use
+                in the gated MLP. Must be one of "relu", "silu", "tanh", or "gelu".
+                Default = "silu"
+            norm (str, optional): The name of the normalization layer to use on the
+                updated atom features. Must be one of "batch", "layer", or None.
+                Default = None
+            use_mlp_out (bool, optional): Whether to apply an MLP output layer to the
+                updated atom features.
+                Default = True
+            resnet (bool, optional): Whether to apply a residual connection to the
+                updated atom features.
+                Default = True
+            gMLP_norm (str, optional): The name of the normalization layer to use on the
+                gated MLP. Must be one of "batch", "layer", or None.
+                Default = None
         """
         super().__init__()
         self.use_mlp_out = use_mlp_out
@@ -175,7 +180,7 @@ def __init__(
             output_dim=bond_fea_dim,
             hidden_dim=hidden_dim,
             dropout=dropout,
-            norm=kwargs.pop("gMLP_norm", "batch"),
+            norm=gMLP_norm,
             activation=activation,
         )
         if self.use_mlp_out:
@@ -202,13 +207,13 @@ def forward(
             bond_weights (Tensor): BondGraph bond weights with shape
                 [num_undirected_bonds, bond_fea_dim]
             angle_feas (Tensor): angle features tensor with shape
-                [num_batch_angles, atom_fea_dim]
+                [num_batch_angles, angle_fea_dim]
             bond_graph (Tensor): Directed BondGraph tensor with shape
                 [num_batched_angles, 3]

         Returns:
             new_bond_feas (Tensor): bond feature tensor with shape
-                [num_batch_atom, bond_fea_dim]
+                [num_undirected_bonds, bond_fea_dim]

         Notes:
             - num_batch_atoms = sum(num_atoms) in batch
@@ -257,9 +262,9 @@ def __init__(
         activation: str = "silu",
         norm: str | None = None,
         resnet: bool = True,
-        **kwargs,
+        gMLP_norm: str | None = None,
     ) -> None:
-        """Create a new AngleUpdate instance.
+        """Initialize the AngleUpdate layer.

         Args:
             atom_fea_dim (int): The dimensionality of the input atom features.
@@ -267,20 +272,21 @@ def __init__(
             angle_fea_dim (int): The dimensionality of the input angle features.
             hidden_dim (int, optional): The dimensionality of the hidden layers
                 in the gated MLP.
-                Default = 0.
-            dropout (float, optional): The dropout probability to apply to the gated MLP.
+                Default = 0
+            dropout (float, optional): The dropout rate to apply to the gated MLP.
                 Default = 0.
             activation (str, optional): The name of the activation function to use
                 in the gated MLP. Must be one of "relu", "silu", "tanh", or "gelu".
-                Default = "silu".
+                Default = "silu"
             norm (str, optional): The name of the normalization layer to use on the
-                updated atom features.
-                Must be one of "batch", "layer", or None.
-                Default = None.
+                updated atom features. Must be one of "batch", "layer", or None.
+                Default = None
             resnet (bool, optional): Whether to apply a residual connection to the
                 updated atom features.
-                Default = True.
-            **kwargs: Additional keyword arguments to pass to the normalization layer.
+                Default = True
+            gMLP_norm (str, optional): The name of the normalization layer to use on the
+                gated MLP. Must be one of "batch", "layer", or None.
+                Default = None
         """
         super().__init__()
         self.resnet = resnet
@@ -290,7 +296,7 @@ def __init__(
             output_dim=angle_fea_dim,
             hidden_dim=hidden_dim,
             dropout=dropout,
-            norm=kwargs.pop("gMLP_norm", "batch"),
+            norm=gMLP_norm,
             activation=activation,
         )
         self.angle_norm = find_normalization(norm, dim=angle_fea_dim)
@@ -310,13 +316,13 @@ def forward(
             bond_feas (Tensor): bond features tensor with shape
                 [num_undirected_bonds, bond_fea_dim]
             angle_feas (Tensor): angle features tensor with shape
-                [num_batch_angles, atom_fea_dim]
+                [num_batch_angles, angle_fea_dim]
             bond_graph (Tensor): Directed BondGraph tensor with shape
                 [num_batched_angles, 3]

         Returns:
             new_angle_feas (Tensor): angle features tensor with shape
-                [num_batch_angles, atom_fea_dim]
+                [num_batch_angles, angle_fea_dim]

         Notes:
             - num_batch_atoms = sum(num_atoms) in batch
@@ -355,11 +361,13 @@ def forward(self, atom_feas: Tensor, atom_owner: Tensor) -> Tensor:

         Args:
             atom_feas (Tensor): batched atom features after convolution layers.
-                shape = [num_batch_atoms, atom_fea_dim]
-            atom_owner (Tensor): graph indices for each atom. shape = [num_batch_atoms]
+                [num_batch_atoms, atom_fea_dim or 1]
+            atom_owner (Tensor): graph indices for each atom.
+                [num_batch_atoms]

         Returns:
-            crystal_feas (Tensor): crystal feature matrix. shape = [n_crystals, atom_fea_dim]
+            crystal_feas (Tensor): crystal feature matrix.
+                [n_crystals, atom_fea_dim or 1]
         """
         return aggregate(atom_feas, atom_owner, average=self.average)
@@ -392,11 +400,13 @@ def forward(self, atom_feas: Tensor, atom_owner: Tensor) -> Tensor:

         Args:
             atom_feas (Tensor): batched atom features after convolution layers.
-                shape = [num_batch_atoms, atom_fea_dim]
-            atom_owner (Tensor): graph indices for each atom. shape = [num_batch_atoms]
+                [num_batch_atoms, atom_fea_dim]
+            atom_owner (Tensor): graph indices for each atom.
+                [num_batch_atoms]

         Returns:
-            crystal_feas (Tensor): crystal feature matrix. shape = [n_crystals, atom_fea_dim]
+            crystal_feas (Tensor): crystal feature matrix.
+                [n_crystals, atom_fea_dim]
         """
         crystal_feas = []
         weights = self.key(atom_feas)  # [n_batch_atom, n_heads]
diff --git a/chgnet/model/model.py b/chgnet/model/model.py
index a6b79749..c36a287a 100644
--- a/chgnet/model/model.py
+++ b/chgnet/model/model.py
@@ -71,7 +71,9 @@ def __init__(
                 Default = 64
             composition_model (nn.Module, optional): attach a composition model to
                 predict energy or initialize a pretrained linear regression (AtomRef).
-                Default = None
+                The default 'MPtrj' is the atom reference energy linear regression
+                trained on all Materials Project relaxation trajectories.
+                Default = 'MPtrj'
             num_radial (int): number of radial basis used in bond basis expansion.
                 Default = 9
             num_angular (int): number of angular basis used in angle basis expansion.
@@ -109,7 +111,9 @@ def __init__(
             non_linearity ('silu' | 'relu' | 'tanh' | 'gelu'): The name of the
                 activation function to use in the gated MLP.
                 Default = "silu".
-            mlp_first (bool): whether to apply mlp fist then pooling.
+            mlp_first (bool): whether to apply the MLP first and then pooling.
+                If set to True, CHGNet essentially computes an energy for each
+                atom and then sums them up; this is used for the pretrained model.
                 Default = True
             atom_graph_cutoff (float): cutoff radius (A) in creating atom_graph,
-                this need to be consistent with the value in training dataloader
+                this needs to be consistent with the value in training dataloader
@@ -539,7 +543,9 @@ def predict_graph(
         return_crystal_feas: bool = False,
         batch_size: int = 100,
     ) -> dict[str, Tensor]:
-        """Args:
+        """Predict from a CrystalGraph.
+
+        Args:
             graph (CrystalGraph): Crystal_Graph or a list of CrystalGraphs to predict.
-            task (str): can be 'e' 'ef', 'em', 'efs', 'efsm'
+            task (str): can be 'e', 'ef', 'em', 'efs', 'efsm'
                 Default = "efsm"
@@ -756,6 +762,9 @@ def from_graphs(
             directed2undirected.append(graph.directed2undirected + n_undirected)

             # Angles
+            # Here we use directed edges to calculate angles and keep only the
+            # undirected bond indices in the bond_graph, so the number of columns
+            # in bond_graph reduces from 5 to 3.
             if len(graph.bond_graph) != 0:
                 bond_vecs_i = torch.index_select(
                     bond_vectors, 0, graph.bond_graph[:, 2]
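Reviewer note (not part of the patch): below is a minimal usage sketch showing how the APIs whose docstrings this diff touches fit together. It assumes the public entry points CHGNet.load(), CrystalGraphConverter, and StructOptimizer, and the output keys documented above ('e' from predict_graph, 'final_structure'/'trajectory' from relax()); treat it as illustrative, not as part of the diff.

    from pymatgen.core import Lattice, Structure

    from chgnet.graph import CrystalGraphConverter
    from chgnet.model import CHGNet, StructOptimizer

    # fcc Cu conventional cell as a small test structure
    structure = Structure(
        Lattice.cubic(3.61),
        ["Cu"] * 4,
        [[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]],
    )

    # Structure -> CrystalGraph, using the default cutoffs documented above
    converter = CrystalGraphConverter(atom_graph_cutoff=5, bond_graph_cutoff=3)
    graph = converter(structure)

    # predict energy/forces/stress/magmoms from the pretrained model
    model = CHGNet.load()
    prediction = model.predict_graph(graph, task="efsm")
    print(prediction["e"])  # energy per atom (eV/atom)

    # relax with the default "FIRE" optimizer; model=None loads the pretrained CHGNet
    relaxer = StructOptimizer(model=None, optimizer_class="FIRE")
    result = relaxer.relax(structure, fmax=0.1)
    print(result["final_structure"], result["trajectory"])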