Skip to content

Commit

Permalink
Merge pull request pfnet-research#15 from mwata/compare_with_mncore
Browse files Browse the repository at this point in the history
add comparison with mncore
  • Loading branch information
masakiwatanabe authored and GitHub Enterprise committed Jun 6, 2023
2 parents a437a3a + 4d733be commit f04e88c
Showing 1 changed file with 44 additions and 10 deletions.
54 changes: 44 additions & 10 deletions export_static_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from torch_dftd.dftd3_xc_params import get_dftd3_default_params

import pfvm.onnx
pfvm.onnx.register_custom_operators()
#import pfvm.onnx
#pfvm.onnx.register_custom_operators()

import pytorch_pfn_extras.onnx as ppe_onnx
from torch_dftd_static.nn.dftd3_module import DFTD3ModuleStatic
Expand Down Expand Up @@ -67,7 +67,7 @@ def __init__(
self.params,
cutoff=cutoff,
cnthr=cnthr,
dtype=torch.float32,
dtype=torch.float64,
cutoff_smoothing=cutoff_smoothing,
)
self.damping = damping
Expand All @@ -92,6 +92,7 @@ def parse_args():
parser.add_argument("--out_dir", type=str, help="onnx output dir", required=True)
parser.add_argument("--pad_num_atoms", type=int, help="num_atoms after padding", required=False)
parser.add_argument("--pad_num_cells", type=int, help="num_cells after padding", required=False)
parser.add_argument("--compare_with", type=str, help="device to compare with (`pfvm` or `mncore`)", default=None)
return parser.parse_args()

def prepare_data(args):
Expand Down Expand Up @@ -120,6 +121,7 @@ def prepare_data(args):
else:
cell = np.eye(3)
cell_volume = np.abs(np.linalg.det(cell))
cell_volume = torch.tensor(cell_volume)

shift_vecs = calc_shift_vecs(cell, pbc, cutoff=cutoff)
shift_vecs = torch.tensor(shift_vecs)
Expand All @@ -144,10 +146,10 @@ def prepare_data(args):
print("n_atoms = ", len(Z), "n_cell = ", len(shift_vecs), file=sys.stderr)
print("atoms = ", atoms, file=sys.stderr)

args = {
inputs = {
"Z": Z,
"pos": pos.type(torch.float32),
"shift_vecs": shift_vecs.type(torch.float32),
"pos": pos.type(torch.float64),
"shift_vecs": shift_vecs.type(torch.float64),
"cell_volume": cell_volume,
"atom_mask": atom_mask,
"shift_mask": shift_mask,
Expand All @@ -156,13 +158,45 @@ def prepare_data(args):
exporter = ExportONNX(cutoff=cutoff, damping="bj")

# compare energy with original implementation
energy = float(exporter.forward(**args))
calc_orig = TorchDFTD3Calculator(atoms=atoms, device="cpu", damping="bj", cutoff=cutoff)
energy = float(exporter.forward(**inputs))
calc_orig = TorchDFTD3Calculator(atoms=atoms, device="cpu", damping="bj", cutoff=cutoff, dtype=torch.float64)
energy_orig = float(atoms.get_potential_energy())
print("energy = ", energy, "eV")
print("energy_orig = ", energy_orig, "eV")
assert abs(energy - energy_orig) < 1e-7 * abs(energy_orig)

print("out_dir = ", out_dir, file=sys.stderr)
ppe_onnx.export_testcase(exporter, tuple(args.values()), out_dir, verbose=True,
input_names=["Z","pos","shift_vecs","cell_volume","atom_mask","shift_mask"])
ppe_onnx.export_testcase(exporter, tuple(inputs.values()), out_dir, verbose=True,
input_names=["Z","pos","shift_vecs","cell_volume","atom_mask","shift_mask"])

if args.compare_with is not None:
from codegen.utils import codegen_tempfile, storage
from mncore.mndevice import get_device
from mncore.runtime_core._context import Context, context
from mncore.runtime_core._registry import Registry
device = get_device(args.compare_with)
mncore_context = Context(device, Registry())
Context.switch_context(mncore_context)
mncore_context.registry.register("model_dftd3", exporter)
options = {
"pfvm_compatible": True,
"float_dtype": "double",
"codegen_dir": "tmp/",
"save_onnx": True,
}
print("compile options: ", options)
with codegen_tempfile.TemporaryDirectoryWithPID() as tmpdir:
dftd3_on_mncore, _ = mncore_context.compile(
"dftd3",
lambda kwargs: exporter(**kwargs),
[],
storage.path(tmpdir),
inputs,
options,
)
result = dftd3_on_mncore(inputs)
energy_mncore = float(result["result"])
print("energy (my impl.) = ", energy, "eV")
print("energy (original impl.) = ", energy_orig, "eV")
print("energy (my impl. on device) = ", energy_mncore, "eV")
assert abs(energy_orig - energy_mncore) < 1e-6 * abs(energy_orig)

0 comments on commit f04e88c

Please sign in to comment.