Skip to content

Commit

Permalink
Fix the jupyter-notebook for M3GNet training
Browse files Browse the repository at this point in the history
  • Loading branch information
kenko911 committed Apr 1, 2024
1 parent d90b536 commit 47c5021
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,13 @@
"elem_list = get_element_list(structures)\n",
"# setup a graph converter\n",
"converter = Structure2Graph(element_types=elem_list, cutoff=4.0)\n",
"# convert the raw dataset into MEGNetDataset\n",
"# convert the raw dataset into M3GNetDataset\n",
"mp_dataset = MGLDataset(\n",
" threebody_cutoff=4.0, structures=structures, converter=converter, labels={\"eform\": eform_per_atom}\n",
" threebody_cutoff=4.0,\n",
" structures=structures,\n",
" converter=converter,\n",
" labels={\"eform\": eform_per_atom},\n",
" include_line_graph=True,\n",
")"
]
},
Expand Down Expand Up @@ -170,7 +174,7 @@
"source": [
"# Model setup\n",
"\n",
"In the next step, we setup the model and the ModelLightningModule. Here, we have initialized a MEGNet model from scratch. Alternatively, you can also load one of the pre-trained models for transfer learning, which may speed up the training."
"In the next step, we setup the model and the ModelLightningModule. Here, we have initialized a M3GNet model from scratch. Alternatively, you can also load one of the pre-trained models for transfer learning, which may speed up the training."
]
},
{
Expand All @@ -180,14 +184,14 @@
"metadata": {},
"outputs": [],
"source": [
"# setup the architecture of MEGNet model\n",
"# setup the architecture of M3GNet model\n",
"model = M3GNet(\n",
" element_types=elem_list,\n",
" is_intensive=True,\n",
" readout_type=\"set2set\",\n",
")\n",
"# setup the MEGNetTrainer\n",
"lit_module = ModelLightningModule(model=model)"
"# setup the M3GNetTrainer\n",
"lit_module = ModelLightningModule(model=model, include_line_graph=True)"
]
},
{
Expand Down Expand Up @@ -221,7 +225,7 @@
"source": [
"# Visualizing the convergence\n",
"\n",
"Finally, we can plot the convergence plot for the loss metrics. You can see that the MAE is already going down nicely with 20 epochs. Obviously, this is nowhere state of the art performance for the formation energies, but a longer training time should lead to results consistent with what was reported in the original MEGNet work."
"Finally, we can plot the convergence plot for the loss metrics. You can see that the MAE is already going down nicely with 20 epochs. Obviously, this is nowhere state of the art performance for the formation energies, but a longer training time should lead to results consistent with what was reported in the original M3GNet work."
]
},
{
Expand Down Expand Up @@ -273,7 +277,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.18"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"\n",
"import numpy as np\n",
"import pytorch_lightning as pl\n",
"from functools import partial\n",
"from dgl.data.utils import split_dataset\n",
"from mp_api.client import MPRester\n",
"from pytorch_lightning.loggers import CSVLogger\n",
Expand Down Expand Up @@ -126,30 +127,28 @@
"element_types = get_element_list(structures)\n",
"converter = Structure2Graph(element_types=element_types, cutoff=5.0)\n",
"dataset = MGLDataset(\n",
" threebody_cutoff=4.0,\n",
" structures=structures,\n",
" converter=converter,\n",
" labels=labels,\n",
" threebody_cutoff=4.0, structures=structures, converter=converter, labels=labels, include_line_graph=True\n",
")\n",
"train_data, val_data, test_data = split_dataset(\n",
" dataset,\n",
" frac_list=[0.8, 0.1, 0.1],\n",
" shuffle=True,\n",
" random_state=42,\n",
")\n",
"my_collate_fn = partial(collate_fn_efs, include_line_graph=True)\n",
"train_loader, val_loader, test_loader = MGLDataLoader(\n",
" train_data=train_data,\n",
" val_data=val_data,\n",
" test_data=test_data,\n",
" collate_fn=collate_fn_efs,\n",
" collate_fn=my_collate_fn,\n",
" batch_size=2,\n",
" num_workers=0,\n",
")\n",
"model = M3GNet(\n",
" element_types=element_types,\n",
" is_intensive=False,\n",
")\n",
"lit_module = PotentialLightningModule(model=model)"
"lit_module = PotentialLightningModule(model=model, include_line_graph=True)"
]
},
{
Expand Down Expand Up @@ -268,7 +267,7 @@
"# download a pre-trained M3GNet\n",
"m3gnet_nnp = matgl.load_model(\"M3GNet-MP-2021.2.8-PES\")\n",
"model_pretrained = m3gnet_nnp.model\n",
"lit_module_finetune = PotentialLightningModule(model=model_pretrained, lr=1e-4)"
"lit_module_finetune = PotentialLightningModule(model=model_pretrained, lr=1e-4, include_line_graph=True)"
]
},
{
Expand Down Expand Up @@ -384,7 +383,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.18"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 47c5021

Please sign in to comment.