From 4d733be02b63dcfa2e83876d4ad6c597664d0ab2 Mon Sep 17 00:00:00 2001 From: Masaki Watanabe Date: Tue, 6 Jun 2023 13:07:48 +0900 Subject: [PATCH] add comparison with mncore --- export_static_onnx.py | 54 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 10 deletions(-) diff --git a/export_static_onnx.py b/export_static_onnx.py index c34574c..82bb257 100644 --- a/export_static_onnx.py +++ b/export_static_onnx.py @@ -8,8 +8,8 @@ from torch_dftd.dftd3_xc_params import get_dftd3_default_params -import pfvm.onnx -pfvm.onnx.register_custom_operators() +#import pfvm.onnx +#pfvm.onnx.register_custom_operators() import pytorch_pfn_extras.onnx as ppe_onnx from torch_dftd_static.nn.dftd3_module import DFTD3ModuleStatic @@ -67,7 +67,7 @@ def __init__( self.params, cutoff=cutoff, cnthr=cnthr, - dtype=torch.float32, + dtype=torch.float64, cutoff_smoothing=cutoff_smoothing, ) self.damping = damping @@ -92,6 +92,7 @@ def parse_args(): parser.add_argument("--out_dir", type=str, help="onnx output dir", required=True) parser.add_argument("--pad_num_atoms", type=int, help="num_atoms after padding", required=False) parser.add_argument("--pad_num_cells", type=int, help="num_cells after padding", required=False) + parser.add_argument("--compare_with", type=str, help="device to compare with (`pfvm` or `mncore`)", default=None) return parser.parse_args() def prepare_data(args): @@ -120,6 +121,7 @@ def prepare_data(args): else: cell = np.eye(3) cell_volume = np.abs(np.linalg.det(cell)) + cell_volume = torch.tensor(cell_volume) shift_vecs = calc_shift_vecs(cell, pbc, cutoff=cutoff) shift_vecs = torch.tensor(shift_vecs) @@ -144,10 +146,10 @@ def prepare_data(args): print("n_atoms = ", len(Z), "n_cell = ", len(shift_vecs), file=sys.stderr) print("atoms = ", atoms, file=sys.stderr) - args = { + inputs = { "Z": Z, - "pos": pos.type(torch.float32), - "shift_vecs": shift_vecs.type(torch.float32), + "pos": pos.type(torch.float64), + "shift_vecs": shift_vecs.type(torch.float64), "cell_volume": cell_volume, "atom_mask": atom_mask, "shift_mask": shift_mask, @@ -156,13 +158,45 @@ def prepare_data(args): exporter = ExportONNX(cutoff=cutoff, damping="bj") # compare energy with original implementation - energy = float(exporter.forward(**args)) - calc_orig = TorchDFTD3Calculator(atoms=atoms, device="cpu", damping="bj", cutoff=cutoff) + energy = float(exporter.forward(**inputs)) + calc_orig = TorchDFTD3Calculator(atoms=atoms, device="cpu", damping="bj", cutoff=cutoff, dtype=torch.float64) energy_orig = float(atoms.get_potential_energy()) print("energy = ", energy, "eV") print("energy_orig = ", energy_orig, "eV") assert abs(energy - energy_orig) < 1e-7 * abs(energy_orig) print("out_dir = ", out_dir, file=sys.stderr) - ppe_onnx.export_testcase(exporter, tuple(args.values()), out_dir, verbose=True, - input_names=["Z","pos","shift_vecs","cell_volume","atom_mask","shift_mask"]) \ No newline at end of file + ppe_onnx.export_testcase(exporter, tuple(inputs.values()), out_dir, verbose=True, + input_names=["Z","pos","shift_vecs","cell_volume","atom_mask","shift_mask"]) + + if args.compare_with is not None: + from codegen.utils import codegen_tempfile, storage + from mncore.mndevice import get_device + from mncore.runtime_core._context import Context, context + from mncore.runtime_core._registry import Registry + device = get_device(args.compare_with) + mncore_context = Context(device, Registry()) + Context.switch_context(mncore_context) + mncore_context.registry.register("model_dftd3", exporter) + options = { + "pfvm_compatible": True, + "float_dtype": "double", + "codegen_dir": "tmp/", + "save_onnx": True, + } + print("compile options: ", options) + with codegen_tempfile.TemporaryDirectoryWithPID() as tmpdir: + dftd3_on_mncore, _ = mncore_context.compile( + "dftd3", + lambda kwargs: exporter(**kwargs), + [], + storage.path(tmpdir), + inputs, + options, + ) + result = dftd3_on_mncore(inputs) + energy_mncore = float(result["result"]) + print("energy (my impl.) = ", energy, "eV") + print("energy (original impl.) = ", energy_orig, "eV") + print("energy (my impl. on device) = ", energy_mncore, "eV") + assert abs(energy_orig - energy_mncore) < 1e-6 * abs(energy_orig)