Skip to content

Commit

Permalink
Extend the predict functionality for multi-fidelity MEGNet model usin…
Browse files Browse the repository at this point in the history
…g mgl command (#123)

* Update the better M3GNet-DIRECT-PES

* extend the mgl predict for multi-fidelity MEGNet bandgap model

* Unit test for predict bandgap using cli command line is added

* -s option is added for state_attr
  • Loading branch information
kenko911 authored Aug 7, 2023
1 parent 1f280f2 commit 2c92c48
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
26 changes: 22 additions & 4 deletions matgl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import warnings

import numpy as np
import torch
from pymatgen.core.structure import Structure
from pymatgen.ext.matproj import MPRester

Expand Down Expand Up @@ -73,10 +74,19 @@ def predict_structure(args):
"""
model = matgl.load_model(args.model)
if args.infile:
for f in args.infile:
structure = Structure.from_file(f)
val = model.predict_structure(structure)
print(f"{args.model} prediction for {f}: {val}.")
if args.model == "MEGNet-MP-2019.4.1-BandGap-mfi":
state_dict = ["PBE", "GLLB-SC", "HSE", "SCAN"]
for count, f in enumerate(args.infile):
s = args.state_attr[count] # Get the corresponding state attribute
structure = Structure.from_file(f)
val = model.predict_structure(structure, torch.tensor(int(s)))
print(f"{args.model} prediction for {f} with {state_dict[int(s)]} bandgap: {val} eV.")

else:
for f in args.infile:
structure = Structure.from_file(f)
val = model.predict_structure(structure)
print(f"{args.model} prediction for {f}: {val} eV/atom.")
if args.mpids:
mpr = MPRester()
for mid in args.mpids:
Expand Down Expand Up @@ -171,6 +181,14 @@ def main():
help="Input files containing structure. Any format supported by pymatgen's Structure.from_file method.",
)

p_predict.add_argument(
"-s",
"--state",
dest="state_attr",
nargs="+",
help="state attributes containing label. This should be an integer.",
)

p_predict.add_argument(
"-m",
"--model",
Expand Down
2 changes: 2 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def test_entrypoint(Mo):
assert os.path.exists("Mo_hello.cif")
exit_status = os.system("mgl relax -i Mo.cif")
assert exit_status == 0
exit_status = os.system("mgl predict -i Mo.cif -s 1 -m MEGNet-MP-2019.4.1-BandGap-mfi")
assert exit_status == 0
exit_status = os.system("mgl predict -i Mo.cif -m MEGNet-MP-2018.6.1-Eform")
assert exit_status == 0
exit_status = os.system("mgl predict -p mp-19017 -m MEGNet-MP-2018.6.1-Eform")
Expand Down

0 comments on commit 2c92c48

Please sign in to comment.