From cb9fd1d5ad432d7a807726fe080e030472aa8ec6 Mon Sep 17 00:00:00 2001 From: Shyue Ping Ong Date: Tue, 20 Jun 2023 10:12:43 -0700 Subject: [PATCH] Updated docs. --- changes.md | 3 + docs/changes.md | 3 + docs/matgl.apps.md | 6 +- docs/matgl.ext.md | 10 +- docs/matgl.graph.md | 22 +- docs/matgl.layers.md | 61 +++-- docs/matgl.md | 32 +-- docs/matgl.models.md | 6 +- docs/matgl.utils.md | 259 ++++++++++-------- docs/tutorials.md | 4 + ...M3GNet Potential with PyTorch Lightning.md | 149 ++++++++++ setup.py | 2 +- 12 files changed, 379 insertions(+), 178 deletions(-) create mode 100644 docs/tutorials/Training a M3GNet Potential with PyTorch Lightning.md diff --git a/changes.md b/changes.md index 6ced1053..ef00bd2a 100644 --- a/changes.md +++ b/changes.md @@ -6,6 +6,9 @@ nav_order: 3 # Change Log +## 0.6.0 +- Refactoring of training utilities. Added example for training an M3GNet potential. + ## 0.5.6 - Minor internal refactoring of basis expansions into `_basis.py`. (@lbluque) diff --git a/docs/changes.md b/docs/changes.md index 6ced1053..ef00bd2a 100644 --- a/docs/changes.md +++ b/docs/changes.md @@ -6,6 +6,9 @@ nav_order: 3 # Change Log +## 0.6.0 +- Refactoring of training utilities. Added example for training an M3GNet potential. + ## 0.5.6 - Minor internal refactoring of basis expansions into `_basis.py`. (@lbluque) diff --git a/docs/matgl.apps.md b/docs/matgl.apps.md index 56d7e579..dbabe743 100644 --- a/docs/matgl.apps.md +++ b/docs/matgl.apps.md @@ -14,7 +14,7 @@ potentials parameterizing the potential energy surface (PES). Implementation of Interatomic Potentials. -### _class_ matgl.apps.pes.Potential(model: nn.Module, data_mean: torch.tensor | None = None, data_std: torch.tensor | None = None, element_refs: np.ndarray | None = None, calc_forces: bool = True, calc_stresses: bool = True, calc_hessian: bool = False) +### _class_ matgl.apps.pes.Potential(model: nn.Module, data_mean: torch.Tensor | None = None, data_std: torch.Tensor | None = None, element_refs: np.ndarray | None = None, calc_forces: bool = True, calc_stresses: bool = True, calc_hessian: bool = False) Bases: `Module`, [`IOMixIn`](matgl.utils.md#matgl.utils.io.IOMixIn) A class representing an interatomic potential. @@ -47,7 +47,7 @@ Initialize Potential from a model and elemental references. -#### forward(g: dgl.DGLGraph, state_attr: torch.tensor | None = None, l_g: dgl.DGLGraph | None = None) +#### forward(g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, l_g: dgl.DGLGraph | None = None) * **Parameters** @@ -64,7 +64,7 @@ Initialize Potential from a model and elemental references. * **Returns** - torch.tensor + torch.Tensor diff --git a/docs/matgl.ext.md b/docs/matgl.ext.md index fef8a8a9..4214392f 100644 --- a/docs/matgl.ext.md +++ b/docs/matgl.ext.md @@ -55,7 +55,7 @@ Get a DGL graph from an input Atoms. -### _class_ matgl.ext.ase.M3GNetCalculator(potential: [Potential](matgl.apps.md#matgl.apps.pes.Potential), state_attr: tensor | None = None, stress_weight: float = 1.0, \*\*kwargs) +### _class_ matgl.ext.ase.M3GNetCalculator(potential: [Potential](matgl.apps.md#matgl.apps.pes.Potential), state_attr: Tensor | None = None, stress_weight: float = 1.0, \*\*kwargs) Bases: `Calculator` M3GNet calculator for ASE. @@ -105,7 +105,7 @@ Perform calculation for an input Atoms. Properties calculator can handle (energy, forces, …) -### _class_ matgl.ext.ase.MolecularDynamics(atoms: Atoms, potential: [Potential](matgl.apps.md#matgl.apps.pes.Potential), state_attr: torch.tensor = None, ensemble: str = 'nvt', temperature: int = 300, timestep: float = 1.0, pressure: float = 6.324209121801212e-07, taut: float | None = None, taup: float | None = None, compressibility_au: float | None = None, trajectory: str | Trajectory | None = None, logfile: str | None = None, loginterval: int = 1, append_trajectory: bool = False) +### _class_ matgl.ext.ase.MolecularDynamics(atoms: Atoms, potential: [Potential](matgl.apps.md#matgl.apps.pes.Potential), state_attr: torch.Tensor = None, ensemble: str = 'nvt', temperature: int = 300, timestep: float = 1.0, pressure: float = 6.324209121801212e-07, taut: float | None = None, taup: float | None = None, compressibility_au: float | None = None, trajectory: str | Trajectory | None = None, logfile: str | None = None, loginterval: int = 1, append_trajectory: bool = False) Bases: `object` Molecular dynamics class. @@ -125,7 +125,7 @@ Init the MD simulation. * **atoms** – - * **state_attr** (*torch.tensor*) – State attr. + * **state_attr** (*torch.Tensor*) – State attr. * **ensemble** (*str*) – choose from ‘nvt’ or ‘npt’. NPT is not tested, @@ -185,7 +185,7 @@ Set new atoms to run MD. -### _class_ matgl.ext.ase.Relaxer(potential: [Potential](matgl.apps.md#matgl.apps.pes.Potential) = None, state_attr: torch.tensor = None, optimizer: Optimizer | str = 'FIRE', relax_cell: bool = True, stress_weight: float = 0.01) +### _class_ matgl.ext.ase.Relaxer(potential: [Potential](matgl.apps.md#matgl.apps.pes.Potential) = None, state_attr: torch.Tensor = None, optimizer: Optimizer | str = 'FIRE', relax_cell: bool = True, stress_weight: float = 0.01) Bases: `object` Relaxer is a class for structural relaxation. @@ -200,7 +200,7 @@ Relaxer is a class for structural relaxation. * **distribution** (*that comes with M3GNet*) – - * **state_attr** (*torch.tensor*) – State attr. + * **state_attr** (*torch.Tensor*) – State attr. * **optimizer** (*str** or **ase Optimizer*) – the optimization algorithm. diff --git a/docs/matgl.graph.md b/docs/matgl.graph.md index c1e9ca0d..268b4644 100644 --- a/docs/matgl.graph.md +++ b/docs/matgl.graph.md @@ -57,8 +57,8 @@ Args: edges: DGL graph edges Returns: -cos_theta: torch.tensor -phi: torch.tensor +cos_theta: torch.Tensor +phi: torch.Tensor triple_bond_lengths (torch.tensor): @@ -140,10 +140,20 @@ Returns: True if file exists. #### load(filename: str = 'dgl_graph.bin', filename_line_graph: str = 'dgl_line_graph.bin', filename_state_attr: str = 'state_attr.pt') -Load dgl graphs -Args: -:filename: Name of file storing dgl graphs -:filename: Name of file storing state attrs. +Load dgl graphs from files. + + +* **Parameters** + + + * **filename** – Name of file storing dgl graphs + + + * **filename_line_graph** – Name of file storing dgl line graphs + + + * **filename_state_attr** – Name of file storing state attrs. + #### process() diff --git a/docs/matgl.layers.md b/docs/matgl.layers.md index 6859bc35..f1d56eb8 100644 --- a/docs/matgl.layers.md +++ b/docs/matgl.layers.md @@ -32,7 +32,7 @@ Init SoftExponential with alpha value. -#### forward(x: tensor) +#### forward(x: Tensor) Evaluate activation function given the input tensor x. @@ -66,7 +66,7 @@ softplus function that is 0 at x=0, the implementation aims at avoiding overflow Initializes the SoftPlus2 class. -#### forward(x: tensor) +#### forward(x: Tensor) Evaluate activation function given the input tensor x. @@ -99,50 +99,51 @@ Bases: `Module` Get total property offset for a system. -### Args: - +Args: property_offset (np.array): a array of elemental property offset. -#### fit(structs_or_graphs: list, element_list: tuple[str], properties: np.typing.NDArray) +#### fit(graphs: list, properties: np.typing.NDArray) Fit the elemental reference values for the properties. * **Parameters** - * **structs_or_graphs** – pymatgen Structures or dgl graphs - - - * **element_list** (*tuple*) – a list of element types + * **graphs** – dgl graphs * **properties** (*np.ndarray*) – array of extensive properties -#### forward(g: dgl.DGLGraph, state_attr: torch.tensor | None = None) +#### forward(g: dgl.DGLGraph, state_attr: torch.Tensor | None = None) Get the total property offset for a system. -Args: -g: a batch of dgl graphs -state_attr: state attributes -Returns: -offset_per_graph: +* **Parameters** + + * **g** – a batch of dgl graphs -#### get_feature_matrix(structs_or_graphs: list, element_list: tuple[str]) -Get the number of atoms for different elements in the structure. + * **state_attr** – state attributes -* **Parameters** - - * **structs_or_graphs** (*list*) – a list of pymatgen Structure or dgl graph + +* **Returns** + + offset_per_graph + + + +#### get_feature_matrix(graphs: list) +Get the number of atoms for different elements in the structure. - * **element_list** – a dictionary containing element types in the training set +* **Parameters** + + **graphs** (*list*) – a list of dgl graph @@ -228,7 +229,7 @@ vanishes at cutoff. * **Parameters** - * **r** – torch.tensor pytorch tensors + * **r** – torch.Tensor pytorch tensors * **cutoff** – float, the cutoff radius @@ -305,7 +306,7 @@ Ref: * **Parameters** - * **r** – torch.tensor distance tensor + * **r** – torch.Tensor distance tensor * **cutoff** – float, cutoff radius @@ -338,7 +339,7 @@ num_centers (int): Number of centers for gaussian expansion. width (float): width of gaussian function. -#### forward(bond_dist: tensor) +#### forward(bond_dist: Tensor) Forward. Args: @@ -373,7 +374,7 @@ Implementation of Set2Set. -#### forward(g: DGLGraph, feat: tensor) +#### forward(g: DGLGraph, feat: Tensor) Defines the computation performed at every call. @@ -421,7 +422,7 @@ An implementation of a Gated multi-layer perceptron. -#### forward(inputs: tensor) +#### forward(inputs: Tensor) Defines the computation performed at every call. Should be overridden by all subclasses. @@ -993,7 +994,7 @@ num_targets: number of target properties. * **Returns** - torch.tensor. + torch.Tensor. @@ -1051,7 +1052,7 @@ Init ThreeBodyInteractions. -#### forward(graph: DGLGraph, line_graph: DGLGraph, three_basis: tensor, three_cutoff: float, node_feat: tensor, edge_feat: tensor) +#### forward(graph: DGLGraph, line_graph: DGLGraph, three_basis: Tensor, three_cutoff: float, node_feat: Tensor, edge_feat: Tensor) Forward function for ThreeBodyInteractions. @@ -1097,10 +1098,10 @@ For the spherical Harmonics function, the column is ordered by * **Parameters** - * **sbf** – torch.tensor spherical bessel function results + * **sbf** – torch.Tensor spherical bessel function results - * **shf** – torch.tensor spherical harmonics function results + * **shf** – torch.Tensor spherical harmonics function results * **max_n** – int, max number of n diff --git a/docs/matgl.md b/docs/matgl.md index 997d691f..8a00752f 100644 --- a/docs/matgl.md +++ b/docs/matgl.md @@ -588,52 +588,52 @@ MatGL (Materials Graph Library) is a graph deep learning library for materials s * [matgl.utils.training module](matgl.utils.md#module-matgl.utils.training) - * [`ModelTrainer`](matgl.utils.md#matgl.utils.training.ModelTrainer) + * [`MatglLightningModuleMixin`](matgl.utils.md#matgl.utils.training.MatglLightningModuleMixin) - * [`ModelTrainer.forward()`](matgl.utils.md#matgl.utils.training.ModelTrainer.forward) + * [`MatglLightningModuleMixin.configure_optimizers()`](matgl.utils.md#matgl.utils.training.MatglLightningModuleMixin.configure_optimizers) - * [`ModelTrainer.loss_fn()`](matgl.utils.md#matgl.utils.training.ModelTrainer.loss_fn) + * [`MatglLightningModuleMixin.on_test_model_eval()`](matgl.utils.md#matgl.utils.training.MatglLightningModuleMixin.on_test_model_eval) - * [`ModelTrainer.step()`](matgl.utils.md#matgl.utils.training.ModelTrainer.step) + * [`MatglLightningModuleMixin.on_train_epoch_end()`](matgl.utils.md#matgl.utils.training.MatglLightningModuleMixin.on_train_epoch_end) - * [`PotentialTrainer`](matgl.utils.md#matgl.utils.training.PotentialTrainer) + * [`MatglLightningModuleMixin.predict_step()`](matgl.utils.md#matgl.utils.training.MatglLightningModuleMixin.predict_step) - * [`PotentialTrainer.forward()`](matgl.utils.md#matgl.utils.training.PotentialTrainer.forward) + * [`MatglLightningModuleMixin.test_step()`](matgl.utils.md#matgl.utils.training.MatglLightningModuleMixin.test_step) - * [`PotentialTrainer.loss_fn()`](matgl.utils.md#matgl.utils.training.PotentialTrainer.loss_fn) + * [`MatglLightningModuleMixin.training_step()`](matgl.utils.md#matgl.utils.training.MatglLightningModuleMixin.training_step) - * [`PotentialTrainer.step()`](matgl.utils.md#matgl.utils.training.PotentialTrainer.step) + * [`MatglLightningModuleMixin.validation_step()`](matgl.utils.md#matgl.utils.training.MatglLightningModuleMixin.validation_step) - * [`TrainerMixin`](matgl.utils.md#matgl.utils.training.TrainerMixin) + * [`ModelLightningModule`](matgl.utils.md#matgl.utils.training.ModelLightningModule) - * [`TrainerMixin.configure_optimizers()`](matgl.utils.md#matgl.utils.training.TrainerMixin.configure_optimizers) + * [`ModelLightningModule.forward()`](matgl.utils.md#matgl.utils.training.ModelLightningModule.forward) - * [`TrainerMixin.on_test_model_eval()`](matgl.utils.md#matgl.utils.training.TrainerMixin.on_test_model_eval) + * [`ModelLightningModule.loss_fn()`](matgl.utils.md#matgl.utils.training.ModelLightningModule.loss_fn) - * [`TrainerMixin.on_train_epoch_end()`](matgl.utils.md#matgl.utils.training.TrainerMixin.on_train_epoch_end) + * [`ModelLightningModule.step()`](matgl.utils.md#matgl.utils.training.ModelLightningModule.step) - * [`TrainerMixin.predict_step()`](matgl.utils.md#matgl.utils.training.TrainerMixin.predict_step) + * [`PotentialLightningModule`](matgl.utils.md#matgl.utils.training.PotentialLightningModule) - * [`TrainerMixin.test_step()`](matgl.utils.md#matgl.utils.training.TrainerMixin.test_step) + * [`PotentialLightningModule.forward()`](matgl.utils.md#matgl.utils.training.PotentialLightningModule.forward) - * [`TrainerMixin.training_step()`](matgl.utils.md#matgl.utils.training.TrainerMixin.training_step) + * [`PotentialLightningModule.loss_fn()`](matgl.utils.md#matgl.utils.training.PotentialLightningModule.loss_fn) - * [`TrainerMixin.validation_step()`](matgl.utils.md#matgl.utils.training.TrainerMixin.validation_step) + * [`PotentialLightningModule.step()`](matgl.utils.md#matgl.utils.training.PotentialLightningModule.step) * [`xavier_init()`](matgl.utils.md#matgl.utils.training.xavier_init) diff --git a/docs/matgl.models.md b/docs/matgl.models.md index 2ffaa050..f69bae9b 100644 --- a/docs/matgl.models.md +++ b/docs/matgl.models.md @@ -116,7 +116,7 @@ The main M3GNet model. -#### forward(g: dgl.DGLGraph, state_attr: torch.tensor | None = None, l_g: dgl.DGLGraph | None = None) +#### forward(g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, l_g: dgl.DGLGraph | None = None) Performs message passing and updates node representations. @@ -145,7 +145,7 @@ Performs message passing and updates node representations. -#### predict_structure(structure, state_feats: torch.tensor | None = None, graph_converter: [GraphConverter](matgl.graph.md#matgl.graph.converters.GraphConverter) | None = None) +#### predict_structure(structure, state_feats: torch.Tensor | None = None, graph_converter: [GraphConverter](matgl.graph.md#matgl.graph.converters.GraphConverter) | None = None) Convenience method to directly predict property from structure. @@ -297,7 +297,7 @@ Forward pass of MEGnet. Executes all blocks. -#### predict_structure(structure, state_feats: torch.tensor | None = None, graph_converter: [GraphConverter](matgl.graph.md#matgl.graph.converters.GraphConverter) | None = None) +#### predict_structure(structure, state_feats: torch.Tensor | None = None, graph_converter: [GraphConverter](matgl.graph.md#matgl.graph.converters.GraphConverter) | None = None) Convenience method to directly predict property from structure. diff --git a/docs/matgl.utils.md b/docs/matgl.utils.md index c18d3842..5fe0653e 100644 --- a/docs/matgl.utils.md +++ b/docs/matgl.utils.md @@ -176,7 +176,7 @@ Convenience method to load a model from a directory or name. Implementations of math functions. -### matgl.utils.maths.broadcast(input_tensor: tensor, target_tensor: tensor, dim: int) +### matgl.utils.maths.broadcast(input_tensor: Tensor, target_tensor: Tensor, dim: int) Broadcast input tensor along a given dimension to match the shape of the target tensor. Modified from torch_scatter library ([https://github.com/rusty1s/pytorch_scatter](https://github.com/rusty1s/pytorch_scatter)). @@ -240,7 +240,7 @@ Give ns = [2, 3], return [0, 1, 0, 1, 2]. * **Parameters** - **ns** – torch.tensor, the number of atoms/bonds array + **ns** – torch.Tensor, the number of atoms/bonds array Returns: range indices @@ -253,7 +253,7 @@ ns = [2, 3], then the function will return [0, 0, 1, 1, 1]. * **Parameters** - **ns** – torch.tensor, the number of atoms/bonds array + **ns** – torch.Tensor, the number of atoms/bonds array @@ -281,7 +281,7 @@ Repeat the first dimension according to n array. Returns: repeated tensor -### matgl.utils.maths.scatter_sum(input_tensor: tensor, segment_ids: tensor, num_segments: int, dim: int) +### matgl.utils.maths.scatter_sum(input_tensor: Tensor, segment_ids: Tensor, num_segments: int, dim: int) Scatter sum operation along the specified dimension. Modified from the torch_scatter library ([https://github.com/rusty1s/pytorch_scatter](https://github.com/rusty1s/pytorch_scatter)). @@ -329,14 +329,14 @@ roots for j0, i.e., sinc(x). Returns: root matrix of size [max_l, max_n] -### matgl.utils.maths.unsorted_segment_fraction(data: tensor, segment_ids: tensor, num_segments: tensor) +### matgl.utils.maths.unsorted_segment_fraction(data: Tensor, segment_ids: Tensor, num_segments: int) Segment fraction :param data: original data :type data: torch.tensor :param segment_ids: segment ids :type segment_ids: torch.tensor :param num_segments: number of segments -:type num_segments: torch.tensor +:type num_segments: int * **Returns** @@ -355,118 +355,123 @@ Segment fraction Utils for training MatGL models. -### _class_ matgl.utils.training.ModelTrainer(model, data_mean=None, data_std=None, loss: str = 'mse_loss', optimizer: Optimizer | None = None, scheduler: lr_scheduler | None = None, lr: float = 0.001, decay_steps: int = 1000, decay_alpha: float = 0.01) -Bases: `TrainerMixin`, `LightningModule` - -Trainer for MEGNet and M3GNet models. +### _class_ matgl.utils.training.MatglLightningModuleMixin() +Bases: `object` -Args: -model: Which type of the model for training -data_mean: average of training data -data_std: standard deviation of training data -loss: loss function used for training -optimizer: optimizer for training -scheduler: scheduler for training -lr: learning rate for training -decay_steps: number of steps for decaying learning rate -decay_alpha: parameter determines the minimum learning rate. +Mix-in class implementing common functions for training. -#### forward(g: dgl.DGLGraph, l_g: dgl.DGLGraph | None = None, state_attr: torch.tensor | None = None) +#### configure_optimizers() +Configure optimizers. -* **Parameters** - - * **g** – dgl Graph +#### on_test_model_eval(\*args, \*\*kwargs) +Executed on model testing. - * **l_g** – Line graph +* **Parameters** + + * **\*args** – Pass-through - * **state_attr** – State attribute. + * **\*\*kwargs** – Pass-through. -* **Returns** - Model prediction. +#### on_train_epoch_end() +Step scheduler every epoch. +#### predict_step(batch, batch_idx, dataloader_idx=0) +Prediction step. -#### loss_fn(loss: Module, labels: tuple, preds: tuple) * **Parameters** - * **loss** – Loss function. + * **batch** – Data batch. - * **labels** – Labels to compute the loss. + * **batch_idx** – Batch index. - * **preds** – Predictions. + * **dataloader_idx** – Data loader index. * **Returns** - total_loss, “MAE”: mae, “RMSE”: rmse} + Prediction -* **Return type** +#### test_step(batch: tuple, batch_idx: int) +Test step. - {“Total_Loss” +* **Parameters** + + * **batch** – Data batch. + + + * **batch_idx** – Batch index. + + + +#### training_step(batch: tuple, batch_idx: int) +Training step. -#### step(batch: tuple) * **Parameters** - **batch** – Batch of training data. + + * **batch** – Data batch. + * **batch_idx** – Batch index. -* **Returns** - results, batch_size +* **Returns** + Total loss. -### _class_ matgl.utils.training.PotentialTrainer(model, element_refs: np.darray | None = None, energy_weight: float = 1.0, force_weight: float = 1.0, stress_weight: float | None = None, data_mean=None, data_std=None, calc_stress: bool = False, loss: str = 'mse_loss', optimizer: Optimizer | None = None, scheduler: lr_scheduler | None = None, lr: float = 0.001, decay_steps: int = 1000, decay_alpha: float = 0.01) -Bases: `TrainerMixin`, `LightningModule` -Trainer for MatGL potentials. -Init PotentialTrainer with key parameters. +#### validation_step(batch: tuple, batch_idx: int) +Validation step. * **Parameters** - * **model** – Which type of the model for training + * **batch** – Data batch. - * **element_refs** – element offset for PES + * **batch_idx** – Batch index. - * **energy_weight** – relative importance of energy +### _class_ matgl.utils.training.ModelLightningModule(model, data_mean=None, data_std=None, loss: str = 'mse_loss', optimizer: Optimizer | None = None, scheduler=None, lr: float = 0.001, decay_steps: int = 1000, decay_alpha: float = 0.01, \*\*kwargs) +Bases: `MatglLightningModuleMixin`, `LightningModule` - * **force_weight** – relative importance of force +A PyTorch.LightningModule for training MEGNet and M3GNet models. +Init ModelLightningModule with key parameters. - * **stress_weight** – relative importance of stress +* **Parameters** - * **data_mean** – average of training data + + * **model** – Which type of the model for training - * **data_std** – standard deviation of training data + * **data_mean** – average of training data - * **calc_stress** – whether stress calculation is required + * **data_std** – standard deviation of training data * **loss** – loss function used for training @@ -487,8 +492,11 @@ Init PotentialTrainer with key parameters. * **decay_alpha** – parameter determines the minimum learning rate. + * **\*\*kwargs** – Passthrough to parent init. + + -#### forward(g: dgl.DGLGraph, l_g: dgl.DGLGraph | None = None, state_attr: torch.tensor | None = None) +#### forward(g: dgl.DGLGraph, l_g: dgl.DGLGraph | None = None, state_attr: torch.Tensor | None = None) * **Parameters** @@ -499,19 +507,17 @@ Init PotentialTrainer with key parameters. * **l_g** – Line graph - * **state_attr** – State attr. + * **state_attr** – State attribute. * **Returns** - energy, force, stress, h + Model prediction. -#### loss_fn(loss: nn.Module, labels: tuple, preds: tuple, energy_weight: float | None = None, force_weight: float | None = None, stress_weight: float | None = None, num_atoms: int | None = None) -Compute losses for EFS. - +#### loss_fn(loss: Module, labels: tuple, preds: tuple) * **Parameters** @@ -519,38 +525,24 @@ Compute losses for EFS. * **loss** – Loss function. - * **labels** – Labels. + * **labels** – Labels to compute the loss. - * **preds** – Predictions + * **preds** – Predictions. - * **energy_weight** – Weight for energy loss. +* **Returns** - * **force_weight** – Weight for force loss. + total_loss, “MAE”: mae, “RMSE”: rmse} - * **stress_weight** – Weight for stress loss. +* **Return type** - * **num_atoms** – Number of atoms. + {“Total_Loss” -Returns: - -```default -{ - "Total_Loss": total_loss, - "Energy_MAE": e_mae, - "Force_MAE": f_mae, - "Stress_MAE": s_mae, - "Energy_RMSE": e_rmse, - "Force_RMSE": f_rmse, - "Stress_RMSE": s_rmse, -} -``` - #### step(batch: tuple) @@ -566,102 +558,141 @@ Returns: -### _class_ matgl.utils.training.TrainerMixin() -Bases: `object` +### _class_ matgl.utils.training.PotentialLightningModule(model, element_refs: np.ndarray | None = None, energy_weight: float = 1.0, force_weight: float = 1.0, stress_weight: float | None = None, data_mean=None, data_std=None, calc_stress: bool = False, loss: str = 'mse_loss', optimizer: Optimizer | None = None, scheduler=None, lr: float = 0.001, decay_steps: int = 1000, decay_alpha: float = 0.01, \*\*kwargs) +Bases: `MatglLightningModuleMixin`, `LightningModule` -Mix-in class implementing common functions for training. - - -#### configure_optimizers() -Configure optimizers. +A PyTorch.LightningModule for training MatGL potentials. +This is slightly different from the ModelLightningModel due to the need to account for energy, forces and stress +losses. -#### on_test_model_eval(\*args, \*\*kwargs) -Executed on model testing. +Init PotentialLightningModule with key parameters. * **Parameters** - * **\*args** – Pass-through + * **model** – Which type of the model for training - * **\*\*kwargs** – Pass-through. + * **element_refs** – element offset for PES + * **energy_weight** – relative importance of energy -#### on_train_epoch_end() -Step scheduler every epoch. + * **force_weight** – relative importance of force -#### predict_step(batch, batch_idx, dataloader_idx=0) -Prediction step. + * **stress_weight** – relative importance of stress -* **Parameters** - - * **batch** – Data batch. + * **data_mean** – average of training data - * **batch_idx** – Batch index. + * **data_std** – standard deviation of training data - * **dataloader_idx** – Data loader index. + * **calc_stress** – whether stress calculation is required + * **loss** – loss function used for training -* **Returns** - Prediction + * **optimizer** – optimizer for training + * **scheduler** – scheduler for training -#### test_step(batch: tuple, batch_idx: int) -Test step. + * **lr** – learning rate for training -* **Parameters** - - * **batch** – Data batch. + * **decay_steps** – number of steps for decaying learning rate - * **batch_idx** – Batch index. + * **decay_alpha** – parameter determines the minimum learning rate. + * **\*\*kwargs** – Passthrough to parent init. + -#### training_step(batch: tuple, batch_idx: int) -Training step. +#### forward(g: dgl.DGLGraph, l_g: dgl.DGLGraph | None = None, state_attr: torch.Tensor | None = None) * **Parameters** - * **batch** – Data batch. + * **g** – dgl Graph - * **batch_idx** – Batch index. + * **l_g** – Line graph + + + * **state_attr** – State attr. * **Returns** - Total loss. + energy, force, stress, h -#### validation_step(batch: tuple, batch_idx: int) -Validation step. +#### loss_fn(loss: nn.Module, labels: tuple, preds: tuple, energy_weight: float | None = None, force_weight: float | None = None, stress_weight: float | None = None, num_atoms: int | None = None) +Compute losses for EFS. * **Parameters** - * **batch** – Data batch. + * **loss** – Loss function. - * **batch_idx** – Batch index. + * **labels** – Labels. + + + * **preds** – Predictions + + + * **energy_weight** – Weight for energy loss. + + + * **force_weight** – Weight for force loss. + + + * **stress_weight** – Weight for stress loss. + + + * **num_atoms** – Number of atoms. + + +Returns: + +```default +{ + "Total_Loss": total_loss, + "Energy_MAE": e_mae, + "Force_MAE": f_mae, + "Stress_MAE": s_mae, + "Energy_RMSE": e_rmse, + "Force_RMSE": f_rmse, + "Stress_RMSE": s_rmse, +} +``` + + +#### step(batch: tuple) + +* **Parameters** + + **batch** – Batch of training data. + + + +* **Returns** + + results, batch_size diff --git a/docs/tutorials.md b/docs/tutorials.md index 775af4cc..dc93874c 100644 --- a/docs/tutorials.md +++ b/docs/tutorials.md @@ -21,4 +21,8 @@ This series of notebooks demonstrate how to load and use the pretrained models f 1. [Benchmarking M3GNet Predictions of Cubic Lattice Parameters](tutorials%2FBenchmarking%20M3GNet%20Predictions%20of%20Cubic%20Lattice%20Parameters.html) +## Training MatGL models + +1. [Training a M3GNet Potential with PyTorch Lightning](tutorials%2FTraining%20a%20M3GNet%20Potential%20with%20PyTorch%20Lightning.html) + [jupyternb]: https://github.com/materialsvirtuallab/matgl/tree/main/examples "Jupyter notebooks" \ No newline at end of file diff --git a/docs/tutorials/Training a M3GNet Potential with PyTorch Lightning.md b/docs/tutorials/Training a M3GNet Potential with PyTorch Lightning.md new file mode 100644 index 00000000..f527d2ef --- /dev/null +++ b/docs/tutorials/Training a M3GNet Potential with PyTorch Lightning.md @@ -0,0 +1,149 @@ +--- +layout: default +title: tutorials/Training a M3GNet Potential with PyTorch Lightning.md +nav_exclude: true +--- +# Introduction + +This notebook demonstrates how to fit a M3GNet potential using PyTorch Lightning with MatGL. + + +```python +from __future__ import annotations + +import os +import shutil +import warnings + +import numpy as np +import pytorch_lightning as pl +from dgl.data.utils import split_dataset +from pymatgen.ext.matproj import MPRester + +from matgl.ext.pymatgen import Structure2Graph, get_element_list +from matgl.graph.data import M3GNetDataset, MGLDataLoader, collate_fn_efs +from matgl.models import M3GNet +from matgl.utils.training import PotentialLightningModule + +# To suppress warnings for clearer output +warnings.simplefilter("ignore") +``` + +For the purposes of demonstration, we will download all Si-O compounds in the Materials Project via the MPRester. The forces and stresses are set to zero, though in a real context, these would be non-zero and obtained from DFT calculations. + + +```python +mpr = MPRester() + +entries = mpr.get_entries_in_chemsys(["Si", "O"]) +structures = [e.structure for e in entries] +energies = [e.energy for e in entries] +forces = [np.zeros((len(s), 3)).tolist() for s in structures] +stresses = [np.zeros((3, 3)).tolist() for s in structures] + +print(f"{len(structures)} downloaded from MP.") +``` + + + Retrieving ThermoDoc documents: 0%| | 0/407 [00:00