From 8cab53b4a1846731cff9e4bfb5b472fd422063d8 Mon Sep 17 00:00:00 2001 From: Shyue Ping Ong Date: Tue, 20 Jun 2023 20:33:58 -0700 Subject: [PATCH] Updated docs. --- changes.md | 3 + docs/changes.md | 3 + docs/matgl.ext.md | 8 +- docs/matgl.models.md | 6 +- docs/matgl.utils.md | 2 +- docs/tutorials.md | 3 +- ...M3GNet Potential with PyTorch Lightning.md | 2 +- ...ion Energy Model with PyTorch Lightning.md | 591 ++++++++++++++++++ setup.py | 2 +- 9 files changed, 607 insertions(+), 13 deletions(-) create mode 100644 docs/tutorials/Training a MEGNet Formation Energy Model with PyTorch Lightning.md diff --git a/changes.md b/changes.md index ef00bd2a..416eb3bb 100644 --- a/changes.md +++ b/changes.md @@ -6,6 +6,9 @@ nav_order: 3 # Change Log +## 0.6.1 +- Bug fix for training loss_fn. + ## 0.6.0 - Refactoring of training utilities. Added example for training an M3GNet potential. diff --git a/docs/changes.md b/docs/changes.md index ef00bd2a..416eb3bb 100644 --- a/docs/changes.md +++ b/docs/changes.md @@ -6,6 +6,9 @@ nav_order: 3 # Change Log +## 0.6.1 +- Bug fix for training loss_fn. + ## 0.6.0 - Refactoring of training utilities. Added example for training an M3GNet potential. diff --git a/docs/matgl.ext.md b/docs/matgl.ext.md index 4214392f..63d02981 100644 --- a/docs/matgl.ext.md +++ b/docs/matgl.ext.md @@ -55,7 +55,7 @@ Get a DGL graph from an input Atoms. -### _class_ matgl.ext.ase.M3GNetCalculator(potential: [Potential](matgl.apps.md#matgl.apps.pes.Potential), state_attr: Tensor | None = None, stress_weight: float = 1.0, \*\*kwargs) +### _class_ matgl.ext.ase.M3GNetCalculator(potential: [Potential](matgl.apps.md#matgl.apps.pes.Potential), state_attr: torch.Tensor | None = None, stress_weight: float = 1.0, \*\*kwargs) Bases: `Calculator` M3GNet calculator for ASE. @@ -101,11 +101,11 @@ Perform calculation for an input Atoms. 
-#### implemented_properties(_: List[str_ _ = ['energy', 'free_energy', 'forces', 'stress', 'hessian'_ ) +#### implemented_properties(_: List[str_ _ = ('energy', 'free_energy', 'forces', 'stress', 'hessian'_ ) Properties calculator can handle (energy, forces, …) -### _class_ matgl.ext.ase.MolecularDynamics(atoms: Atoms, potential: [Potential](matgl.apps.md#matgl.apps.pes.Potential), state_attr: torch.Tensor = None, ensemble: str = 'nvt', temperature: int = 300, timestep: float = 1.0, pressure: float = 6.324209121801212e-07, taut: float | None = None, taup: float | None = None, compressibility_au: float | None = None, trajectory: str | Trajectory | None = None, logfile: str | None = None, loginterval: int = 1, append_trajectory: bool = False) +### _class_ matgl.ext.ase.MolecularDynamics(atoms: Atoms, potential: [Potential](matgl.apps.md#matgl.apps.pes.Potential), state_attr: torch.Tensor | None = None, ensemble: str = 'nvt', temperature: int = 300, timestep: float = 1.0, pressure: float = 6.324209121801212e-07, taut: float | None = None, taup: float | None = None, compressibility_au: float | None = None, trajectory: str | Trajectory | None = None, logfile: str | None = None, loginterval: int = 1, append_trajectory: bool = False) Bases: `object` Molecular dynamics class. @@ -185,7 +185,7 @@ Set new atoms to run MD. -### _class_ matgl.ext.ase.Relaxer(potential: [Potential](matgl.apps.md#matgl.apps.pes.Potential) = None, state_attr: torch.Tensor = None, optimizer: Optimizer | str = 'FIRE', relax_cell: bool = True, stress_weight: float = 0.01) +### _class_ matgl.ext.ase.Relaxer(potential: [Potential](matgl.apps.md#matgl.apps.pes.Potential) | None = None, state_attr: torch.Tensor | None = None, optimizer: Optimizer | str = 'FIRE', relax_cell: bool = True, stress_weight: float = 0.01) Bases: `object` Relaxer is a class for structural relaxation. 
diff --git a/docs/matgl.models.md b/docs/matgl.models.md index f69bae9b..44e385af 100644 --- a/docs/matgl.models.md +++ b/docs/matgl.models.md @@ -188,7 +188,7 @@ Molecules and Crystals._ Chem. Mater. 2019, 31 (9), 3564-3572. DOI: 10.1021/acs. ``` -### _class_ matgl.models._megnet.MEGNet(dim_node_embedding: int = 16, dim_edge_embedding: int = 100, dim_state_embedding: int = 2, ntypes_state: int | None = None, nblocks: int = 3, hidden_layer_sizes_input: tuple[int, ...] = (64, 32), hidden_layer_sizes_conv: tuple[int, ...] = (64, 64, 32), hidden_layer_sizes_output: tuple[int, ...] = (32, 16), nlayers_set2set: int = 1, niters_set2set: int = 2, activation_type: str = 'softplus2', is_classification: bool = False, include_state: bool = True, dropout: float | None = None, graph_transformations: list | None = None, element_types: tuple[str, ...] = ('H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu'), bond_expansion: [BondExpansion](matgl.layers.md#matgl.layers._bond.BondExpansion) | None = None, cutoff: float = 4.0, gauss_width: float = 0.5, \*\*kwargs) +### _class_ matgl.models._megnet.MEGNet(dim_node_embedding: int = 16, dim_edge_embedding: int = 100, dim_state_embedding: int = 2, ntypes_state: int | None = None, nblocks: int = 3, hidden_layer_sizes_input: tuple[int, ...] = (64, 32), hidden_layer_sizes_conv: tuple[int, ...] = (64, 64, 32), hidden_layer_sizes_output: tuple[int, ...] 
= (32, 16), nlayers_set2set: int = 1, niters_set2set: int = 2, activation_type: str = 'softplus2', is_classification: bool = False, include_state: bool = True, dropout: float | None = None, element_types: tuple[str, ...] = ('H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu'), bond_expansion: [BondExpansion](matgl.layers.md#matgl.layers._bond.BondExpansion) | None = None, cutoff: float = 4.0, gauss_width: float = 0.5, \*\*kwargs) Bases: `Module`, [`IOMixIn`](matgl.utils.md#matgl.utils.io.IOMixIn) DGL implementation of MEGNet. @@ -251,10 +251,6 @@ Useful defaults for all arguments have been specified based on MEGNet formation a Bernoulli distribution - * **graph_transformations** – Perform a graph transformation, e.g., incorporate three-body interactions, prior to - performing the GCL updates. - - * **element_types** – Elements included in the training set diff --git a/docs/matgl.utils.md b/docs/matgl.utils.md index 5fe0653e..11c87ed1 100644 --- a/docs/matgl.utils.md +++ b/docs/matgl.utils.md @@ -517,7 +517,7 @@ Init ModelLightningModule with key parameters. -#### loss_fn(loss: Module, labels: tuple, preds: tuple) +#### loss_fn(loss: Module, labels: Tensor, preds: Tensor) * **Parameters** diff --git a/docs/tutorials.md b/docs/tutorials.md index dc93874c..b370692c 100644 --- a/docs/tutorials.md +++ b/docs/tutorials.md @@ -23,6 +23,7 @@ This series of notebooks demonstrate how to load and use the pretrained models f ## Training MatGL models -1. 
[Training a M3GNet Potential with PyTorch Lightning](tutorials%2FTraining%20a%20M3GNet%20Potential%20with%20PyTorch%20Lightning.html) +1. [Training a MEGNet Formation Energy Model](tutorials%2FTraining%20a%20MEGNet%20Formation%20Energy%20Model%20with%20PyTorch%20Lightning.html) +2. [Training a M3GNet Potential](tutorials%2FTraining%20a%20M3GNet%20Potential%20with%20PyTorch%20Lightning.html) [jupyternb]: https://github.com/materialsvirtuallab/matgl/tree/main/examples "Jupyter notebooks" \ No newline at end of file diff --git a/docs/tutorials/Training a M3GNet Potential with PyTorch Lightning.md b/docs/tutorials/Training a M3GNet Potential with PyTorch Lightning.md index f527d2ef..bf8ae6b8 100644 --- a/docs/tutorials/Training a M3GNet Potential with PyTorch Lightning.md +++ b/docs/tutorials/Training a M3GNet Potential with PyTorch Lightning.md @@ -89,7 +89,7 @@ lit_module = PotentialLightningModule(model=model) 100%|███████████████████████████████████████████████████████████████████████████████████████| 407/407 [00:02<00:00, 202.56it/s] -Finally, we will initialize the Pytorch Lightning trainer and run the fitting. Here, the max_epochs is set to 2 just for demonstration purposes. In a real fitting, this would be a much larger number. Also, the `accelerator` +Finally, we will initialize the Pytorch Lightning trainer and run the fitting. Here, the max_epochs is set to 2 just for demonstration purposes. In a real fitting, this would be a much larger number. Also, the `accelerator="cpu"` was set just to ensure compatibility with M1 Macs. In a real world use case, please remove the kwarg or set it to cuda for GPU based training. 
```python diff --git a/docs/tutorials/Training a MEGNet Formation Energy Model with PyTorch Lightning.md b/docs/tutorials/Training a MEGNet Formation Energy Model with PyTorch Lightning.md new file mode 100644 index 00000000..9adca5ef --- /dev/null +++ b/docs/tutorials/Training a MEGNet Formation Energy Model with PyTorch Lightning.md @@ -0,0 +1,591 @@ +--- +layout: default +title: tutorials/Training a MEGNet Formation Energy Model with PyTorch Lightning.md +nav_exclude: true +--- +# Introduction + +This notebook demonstrates how to refit a MEGNet formation energy model using PyTorch Lightning with MatGL. + + +```python +from __future__ import annotations + +import os +import shutil +import warnings +import zipfile + +import pandas as pd +import pytorch_lightning as pl +import torch +from dgl.data.utils import split_dataset +from pymatgen.core import Structure +from tqdm import tqdm + +from matgl.ext.pymatgen import Structure2Graph, get_element_list +from matgl.graph.data import MEGNetDataset, MGLDataLoader, collate_fn +from matgl.layers import BondExpansion +from matgl.models import MEGNet +from matgl.utils.io import RemoteFile +from matgl.utils.training import ModelLightningModule + +# To suppress warnings for clearer output +warnings.simplefilter("ignore") +``` + +We will download the original dataset used in the training of the MEGNet formation energy model (MP.2018.6.1) from figshare. To make it easier, we will also cache the data. + + +```python +def load_dataset() -> tuple[list[Structure], list[str], list[float]]: + """Raw data loading function. 
+ + Returns: + tuple[list[Structure], list[str], list[float]]: structures, mp_id, Eform_per_atom + """ + if not os.path.exists("mp.2018.6.1.json"): + f = RemoteFile("https://figshare.com/ndownloader/files/15087992") + with zipfile.ZipFile(f.local_path) as zf: + zf.extractall(".") + data = pd.read_json("mp.2018.6.1.json") + structures = [] + mp_ids = [] + for mid, structure_str in tqdm(zip(data["material_id"], data["structure"])): + struct = Structure.from_str(structure_str, fmt="cif") + structures.append(struct) + mp_ids.append(mid) + + return structures, mp_ids, data["formation_energy_per_atom"].tolist() +``` + + +```python +# load the MP raw dataset +structures, mp_ids, eform_per_atom = load_dataset() + +# For demo purposes, we are only going to select 100 structures from the entire set of structures. +structures = structures[:100] +eform_per_atom = eform_per_atom[:100] +``` + + 69239it [02:55, 395.06it/s] + + +Here, we set up the dataset. + + +```python +# get element types in the dataset +elem_list = get_element_list(structures) +# setup a graph converter +converter = Structure2Graph(element_types=elem_list, cutoff=4.0) +# convert the raw dataset into MEGNetDataset +mp_dataset = MEGNetDataset( + structures, eform_per_atom, "Eform", converter=converter, initial=0.0, final=5.0, num_centers=100, width=0.5 +) +# separate the dataset into training, validation and test data +train_data, val_data, test_data = split_dataset( + mp_dataset, + frac_list=[0.8, 0.1, 0.1], + shuffle=True, + random_state=42, +) +``` + + 100%|███████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 765.89it/s] + + + +```python +train_loader, val_loader, test_loader = MGLDataLoader( + train_data=train_data, + val_data=val_data, + test_data=test_data, + collate_fn=collate_fn, + batch_size=2, + num_workers=1, +) +``` + +In the next step, we setup the model and the ModelLightningModule. 
+ + +```python +# get the average and standard deviation from the training set +# setup the embedding layer for node attributes +node_embed = torch.nn.Embedding(len(elem_list), 16) +# define the bond expansion +bond_expansion = BondExpansion(rbf_type="Gaussian", initial=0.0, final=5.0, num_centers=100, width=0.5) + +# setup the architecture of MEGNet model +model = MEGNet( + dim_node_embedding=16, + dim_edge_embedding=100, + dim_state_embedding=2, + nblocks=3, + hidden_layer_sizes_input=(64, 32), + hidden_layer_sizes_conv=(64, 64, 32), + nlayers_set2set=1, + niters_set2set=2, + hidden_layer_sizes_output=(32, 16), + is_classification=False, + activation_type="softplus2", + bond_expansion=bond_expansion, + cutoff=4.0, + gauss_width=0.5, +) + +# setup the ModelLightningModule +lit_module = ModelLightningModule(model=model) +``` + +Finally, we will initialize the PyTorch Lightning trainer and run the fitting. Note that max_epochs is set to only 100 here so that the fitting can be demonstrated on a laptop. A real fitting should use max_epochs > 100 and be run in parallel on GPU resources. For the formation energy, it should be around 2000. The `accelerator="cpu"` was set just to ensure compatibility with M1 Macs. In a real world use case, please remove the kwarg or set it to cuda for GPU based training. 
+ + +```python +trainer = pl.Trainer(max_epochs=100, accelerator="cpu") +trainer.fit(model=lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader) +``` + + GPU available: True (mps), used: False + TPU available: False, using: 0 TPU cores + IPU available: False, using: 0 IPUs + HPU available: False, using: 0 HPUs + Missing logger folder: /Users/shyue/repos/matgl/examples/lightning_logs + + | Name | Type | Params + -------------------------------------------- + 0 | model | MEGNet | 189 K + 1 | mae | MeanAbsoluteError | 0 + 2 | rmse | MeanSquaredError | 0 + -------------------------------------------- + 189 K Trainable params + 100 Non-trainable params + 189 K Total params + 0.758 Total estimated model params size (MB) + + + + Sanity Checking: 0it [00:00, ?it/s] + + + + Training: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + 
+ Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + 
Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + + Validation: 0it [00:00, ?it/s] + + + `Trainer.fit` stopped: `max_epochs=100` reached. + + + +```python +# This code just performs cleanup for this notebook. + +for fn in ("dgl_graph.bin", "dgl_line_graph.bin", "state_attr.pt", "labels.json"): + try: + os.remove(fn) + except FileNotFoundError: + pass + +shutil.rmtree("lightning_logs") +``` diff --git a/setup.py b/setup.py index 9fe1879a..25b1f87b 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ setup( name="matgl", - version="0.6.0", + version="0.6.1", author="Tsz Wai Ko, Marcel Nassar, Ji Qi, Santiago Miret, Eliott Liu, Shyue Ping Ong", author_email="t1ko@ucsd.edu, ongsp@ucsd.edu", maintainer="Shyue Ping Ong",