Skip to content

Commit

Permalink
Improve MEGNet training example with plots.
Browse files Browse the repository at this point in the history
  • Loading branch information
Shyue Ping Ong committed Jun 21, 2023
1 parent 773493c commit 9b0a4e8
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 1,453 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import numpy as np
import pytorch_lightning as pl
from dgl.data.utils import split_dataset
from pymatgen.ext.matproj import MPRester
from pytorch_lightning.loggers import CSVLogger

from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.graph.data import M3GNetDataset, MGLDataLoader, collate_fn_efs
Expand Down Expand Up @@ -94,8 +95,8 @@ Finally, we will initialize the Pytorch Lightning trainer and run the fitting. H

```python
# If you wish to disable GPU or MPS (M1 mac) training, use the accelerator="cpu" kwarg.

trainer = pl.Trainer(max_epochs=2, accelerator="cpu")
logger = CSVLogger("logs", name="M3GNet_training")
trainer = pl.Trainer(max_epochs=2, accelerator="cpu", logger=logger)
trainer.fit(model=lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader)
```

Expand Down
Loading

0 comments on commit 9b0a4e8

Please sign in to comment.