diff --git a/matgl/cli.py b/matgl/cli.py index 265ebbdd..57b00c1e 100644 --- a/matgl/cli.py +++ b/matgl/cli.py @@ -8,6 +8,7 @@ import warnings import numpy as np +import torch from pymatgen.core.structure import Structure from pymatgen.ext.matproj import MPRester @@ -73,10 +74,19 @@ def predict_structure(args): """ model = matgl.load_model(args.model) if args.infile: - for f in args.infile: - structure = Structure.from_file(f) - val = model.predict_structure(structure) - print(f"{args.model} prediction for {f}: {val}.") + if args.model == "MEGNet-MP-2019.4.1-BandGap-mfi": + state_dict = ["PBE", "GLLB-SC", "HSE", "SCAN"] + for count, f in enumerate(args.infile): + s = args.state_attr[count] # Get the corresponding state attribute + structure = Structure.from_file(f) + val = model.predict_structure(structure, torch.tensor(int(s))) + print(f"{args.model} prediction for {f} with {state_dict[int(s)]} bandgap: {val} eV.") + + else: + for f in args.infile: + structure = Structure.from_file(f) + val = model.predict_structure(structure) + print(f"{args.model} prediction for {f}: {val} eV/atom.") if args.mpids: mpr = MPRester() for mid in args.mpids: @@ -171,6 +181,14 @@ def main(): help="Input files containing structure. Any format supported by pymatgen's Structure.from_file method.", ) + p_predict.add_argument( + "-s", + "--state", + dest="state_attr", + nargs="+", + help="state attributes containing label. This should be an integer.", + ) + p_predict.add_argument( "-m", "--model", diff --git a/tests/test_cli.py b/tests/test_cli.py index 34cc41a6..35acf63e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -13,6 +13,8 @@ def test_entrypoint(Mo): assert os.path.exists("Mo_hello.cif") exit_status = os.system("mgl relax -i Mo.cif") assert exit_status == 0 + exit_status = os.system("mgl predict -i Mo.cif -s 1 -m MEGNet-MP-2019.4.1-BandGap-mfi") + assert exit_status == 0 exit_status = os.system("mgl predict -i Mo.cif -m MEGNet-MP-2018.6.1-Eform") assert exit_status == 0 exit_status = os.system("mgl predict -p mp-19017 -m MEGNet-MP-2018.6.1-Eform")