Skip to content

Commit

Permalink
extend the mgl predict for multi-fidelity MEGNet bandgap model
Browse files Browse the repository at this point in the history
  • Loading branch information
kenko911 committed Aug 7, 2023
1 parent 6ce9a43 commit d47d28c
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions matgl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import warnings

import numpy as np
import torch
from pymatgen.core.structure import Structure
from pymatgen.ext.matproj import MPRester

Expand Down Expand Up @@ -73,10 +74,27 @@ def predict_structure(args):
"""
model = matgl.load_model(args.model)
if args.infile:
for f in args.infile:
structure = Structure.from_file(f)
val = model.predict_structure(structure)
print(f"{args.model} prediction for {f}: {val}.")
if args.model == "MEGNet-MP-2019.4.1-BandGap-mfi":
state_dict = ["PBE", "GLLB-SC", "HSE", "SCAN"]
if len(args.infile) == 1 or not args.infile[1].isdigit():
print(

Check warning on line 80 in matgl/cli.py

View check run for this annotation

Codecov / codecov/patch

matgl/cli.py#L78-L80

Added lines #L78 - L80 were not covered by tests
"Error: the multi-fidelity MEGNet bandgap model requires the first argument"
"as structure file and the second arguments as state label (int) iteratively!!"
)
exit()
for count in range(0, len(args.infile), 2):
structure = Structure.from_file(args.infile[count])
val = model.predict_structure(structure, torch.tensor(int(args.infile[count + 1])))
print(

Check warning on line 88 in matgl/cli.py

View check run for this annotation

Codecov / codecov/patch

matgl/cli.py#L84-L88

Added lines #L84 - L88 were not covered by tests
f"{args.model} prediction for {args.infile[count]} with "
f"{state_dict[int(args.infile[count + 1])]} bandgap: {val} eV."
)

else:
for f in args.infile:
structure = Structure.from_file(f)
val = model.predict_structure(structure)
print(f"{args.model} prediction for {f}: {val} eV/atom.")
if args.mpids:
mpr = MPRester()
for mid in args.mpids:
Expand Down

0 comments on commit d47d28c

Please sign in to comment.