diff --git a/matgl/cli.py b/matgl/cli.py index 265ebbdd..6778120f 100644 --- a/matgl/cli.py +++ b/matgl/cli.py @@ -8,6 +8,7 @@ import warnings import numpy as np +import torch from pymatgen.core.structure import Structure from pymatgen.ext.matproj import MPRester @@ -73,10 +74,27 @@ def predict_structure(args): """ model = matgl.load_model(args.model) if args.infile: - for f in args.infile: - structure = Structure.from_file(f) - val = model.predict_structure(structure) - print(f"{args.model} prediction for {f}: {val}.") + if args.model == "MEGNet-MP-2019.4.1-BandGap-mfi": + state_dict = ["PBE", "GLLB-SC", "HSE", "SCAN"] + if len(args.infile) == 1 or not args.infile[1].isdigit(): + print( + "Error: the multi-fidelity MEGNet bandgap model requires the first argument" + "as structure file and the second arguments as state label (int) iteratively!!" + ) + exit() + for count in range(0, len(args.infile), 2): + structure = Structure.from_file(args.infile[count]) + val = model.predict_structure(structure, torch.tensor(int(args.infile[count + 1]))) + print( + f"{args.model} prediction for {args.infile[count]} with " + f"{state_dict[int(args.infile[count + 1])]} bandgap: {val} eV." + ) + + else: + for f in args.infile: + structure = Structure.from_file(f) + val = model.predict_structure(structure) + print(f"{args.model} prediction for {f}: {val} eV/atom.") if args.mpids: mpr = MPRester() for mid in args.mpids: